Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions toolium/test/utils/ai_utils/test_text_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_assert_text_similarity_with_default_method(similarity_mock):
input_text = 'Today it will be sunny'
expected_text = 'Today is sunny'
assert_text_similarity(input_text, expected_text, threshold=0.8)
similarity_mock.assert_called_once_with(input_text, expected_text)
similarity_mock.assert_called_once_with(input_text, expected_text, None)


@pytest.mark.skip(reason='Sentence Transformers model is not available in the CI environment')
Expand All @@ -157,7 +157,7 @@ def test_assert_text_similarity_with_configured_method(similarity_mock):
input_text = 'Today it will be sunny'
expected_text = 'Today is sunny'
assert_text_similarity(input_text, expected_text, threshold=0.8)
similarity_mock.assert_called_once_with(input_text, expected_text)
similarity_mock.assert_called_once_with(input_text, expected_text, None)


@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
Expand All @@ -173,4 +173,70 @@ def test_assert_text_similarity_with_configured_and_explicit_method(similarity_m
input_text = 'Today it will be sunny'
expected_text = 'Today is sunny'
assert_text_similarity(input_text, expected_text, threshold=0.8, similarity_method='spacy')
similarity_mock.assert_called_once_with(input_text, expected_text)
similarity_mock.assert_called_once_with(input_text, expected_text, None)


@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
def test_assert_text_similarity_with_configured_and_explicit_model(similarity_mock):
    """An explicit model_name must be forwarded to the similarity method selected via config."""
    similarity_mock.return_value = 0.9
    config = DriverWrappersPool.get_default_wrapper().config
    try:
        config.add_section('AI')
    except Exception:
        pass  # section already exists
    config.set('AI', 'text_similarity_method', 'spacy')

    text, expected = 'Today it will be sunny', 'Today is sunny'
    assert_text_similarity(text, expected, threshold=0.8, model_name='en_core_web_lg')
    similarity_mock.assert_called_once_with(text, expected, 'en_core_web_lg')


@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
def test_assert_text_similarity_with_configured_and_explicit_method_and_model(similarity_mock):
    """An explicit similarity_method plus model_name must override the configured method."""
    similarity_mock.return_value = 0.9
    config = DriverWrappersPool.get_default_wrapper().config
    try:
        config.add_section('AI')
    except Exception:
        pass  # section already exists
    config.set('AI', 'text_similarity_method', 'sentence_transformers')

    text, expected = 'Today it will be sunny', 'Today is sunny'
    assert_text_similarity(text, expected, threshold=0.8, similarity_method='spacy',
                           model_name='en_core_web_lg')
    similarity_mock.assert_called_once_with(text, expected, 'en_core_web_lg')


@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_openai')
def test_assert_text_similarity_with_explicit_openai(similarity_mock):
    """The explicit 'openai' method must receive model_name positionally and azure as a keyword."""
    similarity_mock.return_value = 0.9
    config = DriverWrappersPool.get_default_wrapper().config
    try:
        config.add_section('AI')
    except Exception:
        pass  # section already exists
    config.set('AI', 'spacy_model', 'en_core_web_md')

    text, expected = 'Today it will be sunny', 'Today is sunny'
    assert_text_similarity(text, expected, threshold=0.8, similarity_method='openai',
                           azure=True, model_name='gpt-4o-mini')
    similarity_mock.assert_called_once_with(text, expected, 'gpt-4o-mini', azure=True)


@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_openai')
def test_azure_openai_request_params(similarity_mock):
    """The configured 'azure_openai' method must delegate to the OpenAI helper with azure=True."""
    similarity_mock.return_value = 0.9
    config = DriverWrappersPool.get_default_wrapper().config
    try:
        config.add_section('AI')
    except Exception:
        pass  # section already exists
    config.set('AI', 'text_similarity_method', 'azure_openai')

    text, expected = 'Today it will be sunny', 'Today is sunny'
    assert_text_similarity(text, expected, threshold=0.8)
    similarity_mock.assert_called_once_with(text, expected, None, azure=True)
7 changes: 4 additions & 3 deletions toolium/utils/ai_utils/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,23 @@
logger = logging.getLogger(__name__)


def openai_request(system_message, user_message, model_name=None, azure=False):
def openai_request(system_message, user_message, model_name=None, azure=False, **kwargs):
"""
Make a request to OpenAI API (Azure or standard)

:param system_message: system message to set the behavior of the assistant
:param user_message: user message with the request
:param model: model to use
:param model_name: name of the model to use
:param azure: whether to use Azure OpenAI or standard OpenAI
:param kwargs: additional parameters to be passed to the OpenAI client (azure_endpoint, timeout, etc.)
:returns: response from OpenAI
"""
if OpenAI is None:
raise ImportError("OpenAI is not installed. Please run 'pip install toolium[ai]' to use OpenAI features")
config = DriverWrappersPool.get_default_wrapper().config
model_name = model_name or config.get_optional('AI', 'openai_model', 'gpt-4o-mini')
logger.info(f"Calling to OpenAI API with model {model_name}")
client = AzureOpenAI() if azure else OpenAI()
client = AzureOpenAI(**kwargs) if azure else OpenAI(**kwargs)
completion = client.chat.completions.create(
model=model_name,
messages=[
Expand Down
5 changes: 3 additions & 2 deletions toolium/utils/ai_utils/spacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,18 @@


@lru_cache(maxsize=8)
def get_spacy_model(model_name, **kwargs):
    """
    Get spaCy model.
    This method uses lru cache to get spaCy model to improve performance.

    NOTE(review): lru_cache requires every argument to be hashable, so list-valued
    spaCy options (e.g. disable=['ner']) will raise TypeError here — callers should
    pass tuples instead (e.g. disable=('ner',)); confirm this is acceptable.

    :param model_name: spaCy model name
    :param kwargs: additional parameters to be used by spaCy (disable, exclude, etc.)
    :return: spaCy model, or None when spaCy is not installed
    """
    if spacy is None:
        # Optional dependency missing; caller is responsible for raising a helpful error
        return None
    return spacy.load(model_name, **kwargs)


def is_negator(tok):
Expand Down
29 changes: 19 additions & 10 deletions toolium/utils/ai_utils/text_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@
logger = logging.getLogger(__name__)


def get_text_similarity_with_spacy(text, expected_text, model_name=None):
def get_text_similarity_with_spacy(text, expected_text, model_name=None, **kwargs):
"""
Return similarity between two texts using spaCy.
This method normalize both texts before comparing them.

:param text: string to compare
:param expected_text: string with the expected text
:param model_name: name of the spaCy model to use
:param kwargs: additional parameters to be used by spaCy (disable, exclude, etc.)
:returns: similarity score between the two texts
"""
# NOTE: spaCy similarity performance can be enhanced using some strategies like:
Expand All @@ -49,7 +50,7 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
# - Preprocessing texts. Now we only preprocess negations.
config = DriverWrappersPool.get_default_wrapper().config
model_name = model_name or config.get_optional('AI', 'spacy_model', 'en_core_web_md')
model = get_spacy_model(model_name)
model = get_spacy_model(model_name, **kwargs)
if model is None:
raise ImportError("spaCy is not installed. Please run 'pip install toolium[ai]' to use spaCy features")
text = model(preprocess_with_ud_negation(text, model))
Expand All @@ -59,35 +60,38 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
return similarity


def get_text_similarity_with_sentence_transformers(text, expected_text, model_name=None, **kwargs):
    """
    Return similarity between two texts using Sentence Transformers

    :param text: string to compare
    :param expected_text: string with the expected text
    :param model_name: name of the Sentence Transformers model to use
    :param kwargs: additional parameters to be used by SentenceTransformer (modules, device, prompts, etc.)
    :returns: similarity score between the two texts
    :raises ImportError: when the optional Sentence Transformers dependency is not installed
    """
    if SentenceTransformer is None:
        raise ImportError("Sentence Transformers is not installed. Please run 'pip install toolium[ai]'"
                          " to use Sentence Transformers features")
    config = DriverWrappersPool.get_default_wrapper().config
    # Explicit model_name takes precedence over the configured one, then the library default
    model_name = model_name or config.get_optional('AI', 'sentence_transformers_model', 'all-mpnet-base-v2')
    model = SentenceTransformer(model_name, **kwargs)
    similarity = float(model.similarity(model.encode(expected_text), model.encode(text)))
    # similarity can be slightly > 1 due to float precision
    similarity = min(similarity, 1)
    logger.info(f"Sentence Transformers similarity: {similarity} between '{text}' and '{expected_text}'")
    return similarity


def get_text_similarity_with_openai(text, expected_text, azure=False):
def get_text_similarity_with_openai(text, expected_text, model_name=None, azure=False, **kwargs):
"""
Return semantic similarity between two texts using OpenAI LLM

:param text: string to compare
:param expected_text: string with the expected text
:param model_name: name of the OpenAI model to use
:param azure: whether to use Azure OpenAI or standard OpenAI
:param kwargs: additional parameters to be used by OpenAI client
:returns: tuple with similarity score between the two texts and explanation
"""
system_message = (
Expand All @@ -102,7 +106,7 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
f"The expected answer is: {expected_text}."
f" The LLM answer is: {text}."
)
response = openai_request(system_message, user_message, azure=azure)
response = openai_request(system_message, user_message, model_name, azure, **kwargs)
try:
response = json.loads(response)
similarity = float(response['similarity'])
Expand All @@ -114,18 +118,20 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
return similarity


def get_text_similarity_with_azure_openai(text, expected_text, model_name=None, **kwargs):
    """
    Return semantic similarity between two texts using Azure OpenAI LLM.
    Thin wrapper over get_text_similarity_with_openai forcing azure=True.

    :param text: string to compare
    :param expected_text: string with the expected text
    :param model_name: name of the Azure OpenAI model to use
    :param kwargs: additional parameters to be used by Azure OpenAI client
    :returns: similarity score between the two texts
    """
    # NOTE(review): previous docstring claimed a '(similarity, explanation)' tuple return,
    # but get_text_similarity_with_openai returns only the similarity score — confirm.
    return get_text_similarity_with_openai(text, expected_text, model_name, azure=True, **kwargs)


def assert_text_similarity(text, expected_texts, threshold, similarity_method=None):
def assert_text_similarity(text, expected_texts, threshold, similarity_method=None, model_name=None, **kwargs):
"""
Get similarity between one text and a list of expected texts and assert if any of the expected texts is similar.

Expand All @@ -134,14 +140,17 @@ def assert_text_similarity(text, expected_texts, threshold, similarity_method=No
:param threshold: minimum similarity score to consider texts similar
:param similarity_method: method to use for text comparison ('spacy', 'sentence_transformers', 'openai'
or 'azure_openai')
:param model_name: model name to use for the similarity method
:param kwargs: additional parameters to be used by comparison methods
"""
config = DriverWrappersPool.get_default_wrapper().config
similarity_method = similarity_method or config.get_optional('AI', 'text_similarity_method', 'spacy')
expected_texts = [expected_texts] if isinstance(expected_texts, str) else expected_texts
error_message = ""
for expected_text in expected_texts:
try:
similarity = globals()[f'get_text_similarity_with_{similarity_method}'](text, expected_text)
similarity = globals()[f'get_text_similarity_with_{similarity_method}'](text, expected_text,
model_name, **kwargs)
except KeyError:
raise ValueError(f"Unknown similarity_method: '{similarity_method}', please use 'spacy',"
f" 'sentence_transformers', 'openai' or 'azure_openai'")
Expand Down