diff --git a/toolium/test/utils/ai_utils/test_text_similarity.py b/toolium/test/utils/ai_utils/test_text_similarity.py
index 3b28feb9..322dd3cc 100644
--- a/toolium/test/utils/ai_utils/test_text_similarity.py
+++ b/toolium/test/utils/ai_utils/test_text_similarity.py
@@ -140,7 +140,7 @@ def test_assert_text_similarity_with_default_method(similarity_mock):
     input_text = 'Today it will be sunny'
     expected_text = 'Today is sunny'
     assert_text_similarity(input_text, expected_text, threshold=0.8)
-    similarity_mock.assert_called_once_with(input_text, expected_text)
+    similarity_mock.assert_called_once_with(input_text, expected_text, None)
 
 
 @pytest.mark.skip(reason='Sentence Transformers model is not available in the CI environment')
@@ -157,7 +157,7 @@ def test_assert_text_similarity_with_configured_method(similarity_mock):
     input_text = 'Today it will be sunny'
     expected_text = 'Today is sunny'
     assert_text_similarity(input_text, expected_text, threshold=0.8)
-    similarity_mock.assert_called_once_with(input_text, expected_text)
+    similarity_mock.assert_called_once_with(input_text, expected_text, None)
 
 
 @mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
@@ -173,4 +173,70 @@ def test_assert_text_similarity_with_configured_and_explicit_method(similarity_m
     input_text = 'Today it will be sunny'
     expected_text = 'Today is sunny'
     assert_text_similarity(input_text, expected_text, threshold=0.8, similarity_method='spacy')
-    similarity_mock.assert_called_once_with(input_text, expected_text)
+    similarity_mock.assert_called_once_with(input_text, expected_text, None)
+
+
+@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
+def test_assert_text_similarity_with_configured_and_explicit_model(similarity_mock):
+    config = DriverWrappersPool.get_default_wrapper().config
+    try:
+        config.add_section('AI')
+    except Exception:
+        pass
+    config.set('AI', 'text_similarity_method', 'spacy')
+    similarity_mock.return_value = 0.9
+
+    input_text = 'Today it will be sunny'
+    expected_text = 'Today is sunny'
+    assert_text_similarity(input_text, expected_text, threshold=0.8, model_name='en_core_web_lg')
+    similarity_mock.assert_called_once_with(input_text, expected_text, 'en_core_web_lg')
+
+
+@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
+def test_assert_text_similarity_with_configured_and_explicit_method_and_model(similarity_mock):
+    config = DriverWrappersPool.get_default_wrapper().config
+    try:
+        config.add_section('AI')
+    except Exception:
+        pass
+    config.set('AI', 'text_similarity_method', 'sentence_transformers')
+    similarity_mock.return_value = 0.9
+
+    input_text = 'Today it will be sunny'
+    expected_text = 'Today is sunny'
+    assert_text_similarity(input_text, expected_text, threshold=0.8, similarity_method='spacy',
+                           model_name='en_core_web_lg')
+    similarity_mock.assert_called_once_with(input_text, expected_text, 'en_core_web_lg')
+
+
+@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_openai')
+def test_assert_text_similarity_with_explicit_openai(similarity_mock):
+    config = DriverWrappersPool.get_default_wrapper().config
+    try:
+        config.add_section('AI')
+    except Exception:
+        pass
+    config.set('AI', 'spacy_model', 'en_core_web_md')
+    similarity_mock.return_value = 0.9
+
+    input_text = 'Today it will be sunny'
+    expected_text = 'Today is sunny'
+    assert_text_similarity(input_text, expected_text, threshold=0.8, similarity_method='openai',
+                           azure=True, model_name='gpt-4o-mini')
+    similarity_mock.assert_called_once_with(input_text, expected_text, 'gpt-4o-mini', azure=True)
+
+
+@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_openai')
+def test_azure_openai_request_params(similarity_mock):
+    config = DriverWrappersPool.get_default_wrapper().config
+    try:
+        config.add_section('AI')
+    except Exception:
+        pass
+    config.set('AI', 'text_similarity_method', 'azure_openai')
+    similarity_mock.return_value = 0.9
+
+    input_text = 'Today it will be sunny'
+    expected_text = 'Today is sunny'
+    assert_text_similarity(input_text, expected_text, threshold=0.8)
+    similarity_mock.assert_called_once_with(input_text, expected_text, None, azure=True)
diff --git a/toolium/utils/ai_utils/openai.py b/toolium/utils/ai_utils/openai.py
index 3cef27ac..3a1d6072 100644
--- a/toolium/utils/ai_utils/openai.py
+++ b/toolium/utils/ai_utils/openai.py
@@ -32,14 +32,15 @@
 logger = logging.getLogger(__name__)
 
 
-def openai_request(system_message, user_message, model_name=None, azure=False):
+def openai_request(system_message, user_message, model_name=None, azure=False, **kwargs):
     """
     Make a request to OpenAI API (Azure or standard)
 
     :param system_message: system message to set the behavior of the assistant
     :param user_message: user message with the request
-    :param model: model to use
+    :param model_name: name of the model to use
     :param azure: whether to use Azure OpenAI or standard OpenAI
+    :param kwargs: additional parameters to be passed to the OpenAI client (azure_endpoint, timeout, etc.)
     :returns: response from OpenAI
     """
     if OpenAI is None:
@@ -47,7 +48,7 @@ def openai_request(system_message, user_message, model_name=None, azure=False):
     config = DriverWrappersPool.get_default_wrapper().config
     model_name = model_name or config.get_optional('AI', 'openai_model', 'gpt-4o-mini')
     logger.info(f"Calling to OpenAI API with model {model_name}")
-    client = AzureOpenAI() if azure else OpenAI()
+    client = AzureOpenAI(**kwargs) if azure else OpenAI(**kwargs)
     completion = client.chat.completions.create(
         model=model_name,
         messages=[
diff --git a/toolium/utils/ai_utils/spacy.py b/toolium/utils/ai_utils/spacy.py
index a1abc420..d3ff5614 100644
--- a/toolium/utils/ai_utils/spacy.py
+++ b/toolium/utils/ai_utils/spacy.py
@@ -31,17 +31,18 @@
 
 
 @lru_cache(maxsize=8)
-def get_spacy_model(model_name):
+def get_spacy_model(model_name, **kwargs):
     """
     get spaCy model.
     This method uses lru cache to get spaCy model to improve performance.
 
     :param model_name: spaCy model name
+    :param kwargs: additional hashable parameters to be used by spaCy (e.g. disable=('parser',), not a list)
     :return: spaCy model
     """
     if spacy is None:
         return None
-    return spacy.load(model_name)
+    return spacy.load(model_name, **kwargs)
 
 
 def is_negator(tok):
diff --git a/toolium/utils/ai_utils/text_similarity.py b/toolium/utils/ai_utils/text_similarity.py
index 11b33ccd..6de68119 100644
--- a/toolium/utils/ai_utils/text_similarity.py
+++ b/toolium/utils/ai_utils/text_similarity.py
@@ -33,7 +33,7 @@
 logger = logging.getLogger(__name__)
 
 
-def get_text_similarity_with_spacy(text, expected_text, model_name=None):
+def get_text_similarity_with_spacy(text, expected_text, model_name=None, **kwargs):
     """
     Return similarity between two texts using spaCy.
     This method normalize both texts before comparing them.
@@ -41,6 +41,7 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
     :param text: string to compare
     :param expected_text: string with the expected text
     :param model_name: name of the spaCy model to use
+    :param kwargs: additional parameters to be used by spaCy (disable, exclude, etc.)
     :returns: similarity score between the two texts
     """
     # NOTE: spaCy similarity performance can be enhanced using some strategies like:
@@ -49,7 +50,9 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
     # - Preprocessing texts. Now we only preprocess negations.
     config = DriverWrappersPool.get_default_wrapper().config
     model_name = model_name or config.get_optional('AI', 'spacy_model', 'en_core_web_md')
-    model = get_spacy_model(model_name)
+    # get_spacy_model is decorated with lru_cache, so every argument must be hashable: convert lists to tuples
+    hashable_kwargs = {key: tuple(value) if isinstance(value, list) else value for key, value in kwargs.items()}
+    model = get_spacy_model(model_name, **hashable_kwargs)
     if model is None:
         raise ImportError("spaCy is not installed. Please run 'pip install toolium[ai]' to use spaCy features")
     text = model(preprocess_with_ud_negation(text, model))
@@ -59,13 +62,14 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
     return similarity
 
 
-def get_text_similarity_with_sentence_transformers(text, expected_text, model_name=None):
+def get_text_similarity_with_sentence_transformers(text, expected_text, model_name=None, **kwargs):
     """
     Return similarity between two texts using Sentence Transformers
 
     :param text: string to compare
     :param expected_text: string with the expected text
     :param model_name: name of the Sentence Transformers model to use
+    :param kwargs: additional parameters to be used by SentenceTransformer (modules, device, prompts, etc.)
     :returns: similarity score between the two texts
     """
     if SentenceTransformer is None:
@@ -73,7 +77,7 @@ def get_text_similarity_with_sentence_transformers(text, expected_text, model_na
                           " to use Sentence Transformers features")
     config = DriverWrappersPool.get_default_wrapper().config
     model_name = model_name or config.get_optional('AI', 'sentence_transformers_model', 'all-mpnet-base-v2')
-    model = SentenceTransformer(model_name)
+    model = SentenceTransformer(model_name, **kwargs)
     similarity = float(model.similarity(model.encode(expected_text), model.encode(text)))
     # similarity can be slightly > 1 due to float precision
     similarity = 1 if similarity > 1 else similarity
@@ -81,13 +85,15 @@ def get_text_similarity_with_sentence_transformers(text, expected_text, model_na
     return similarity
 
 
-def get_text_similarity_with_openai(text, expected_text, azure=False):
+def get_text_similarity_with_openai(text, expected_text, model_name=None, azure=False, **kwargs):
     """
     Return semantic similarity between two texts using OpenAI LLM
 
     :param text: string to compare
     :param expected_text: string with the expected text
+    :param model_name: name of the OpenAI model to use
     :param azure: whether to use Azure OpenAI or standard OpenAI
+    :param kwargs: additional parameters to be used by OpenAI client
     :returns: tuple with similarity score between the two texts and explanation
     """
     system_message = (
@@ -102,7 +108,7 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
         f"The expected answer is: {expected_text}."
         f" The LLM answer is: {text}."
     )
-    response = openai_request(system_message, user_message, azure=azure)
+    response = openai_request(system_message, user_message, model_name, azure, **kwargs)
     try:
         response = json.loads(response)
         similarity = float(response['similarity'])
@@ -114,18 +120,20 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
     return similarity
 
 
-def get_text_similarity_with_azure_openai(text, expected_text):
+def get_text_similarity_with_azure_openai(text, expected_text, model_name=None, **kwargs):
     """
     Return semantic similarity between two texts using Azure OpenAI LLM
 
     :param text: string to compare
     :param expected_text: string with the expected text
+    :param model_name: name of the Azure OpenAI model to use
+    :param kwargs: additional parameters to be used by Azure OpenAI client
     :returns: tuple with similarity score between the two texts and explanation
     """
-    return get_text_similarity_with_openai(text, expected_text, azure=True)
+    return get_text_similarity_with_openai(text, expected_text, model_name, azure=True, **kwargs)
 
 
-def assert_text_similarity(text, expected_texts, threshold, similarity_method=None):
+def assert_text_similarity(text, expected_texts, threshold, similarity_method=None, model_name=None, **kwargs):
     """
     Get similarity between one text and a list of expected texts and assert if any of the expected texts is similar.
 
@@ -134,6 +142,8 @@ def assert_text_similarity(text, expected_texts, threshold, similarity_method=No
     :param threshold: minimum similarity score to consider texts similar
     :param similarity_method: method to use for text comparison ('spacy', 'sentence_transformers', 'openai'
                               or 'azure_openai')
+    :param model_name: model name to use for the similarity method
+    :param kwargs: additional parameters to be used by comparison methods
     """
     config = DriverWrappersPool.get_default_wrapper().config
     similarity_method = similarity_method or config.get_optional('AI', 'text_similarity_method', 'spacy')
@@ -141,7 +151,8 @@ def assert_text_similarity(text, expected_texts, threshold, similarity_method=No
     error_message = ""
     for expected_text in expected_texts:
         try:
-            similarity = globals()[f'get_text_similarity_with_{similarity_method}'](text, expected_text)
+            similarity = globals()[f'get_text_similarity_with_{similarity_method}'](text, expected_text,
+                                                                                    model_name, **kwargs)
         except KeyError:
             raise ValueError(f"Unknown similarity_method: '{similarity_method}', please use 'spacy',"
                              f" 'sentence_transformers', 'openai' or 'azure_openai'")