diff --git a/autogpt/memory/vector/utils.py b/autogpt/memory/vector/utils.py index eb69125666a..1b050d56259 100644 --- a/autogpt/memory/vector/utils.py +++ b/autogpt/memory/vector/utils.py @@ -1,3 +1,4 @@ +from contextlib import suppress from typing import Any, overload import numpy as np @@ -12,12 +13,12 @@ @overload -def get_embedding(input: str | TText) -> Embedding: +def get_embedding(input: str | TText, config: Config) -> Embedding: ... @overload -def get_embedding(input: list[str] | list[TText]) -> list[Embedding]: +def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]: ... @@ -37,9 +38,16 @@ def get_embedding( if isinstance(input, str): input = input.replace("\n", " ") + + with suppress(NotImplementedError): + return _get_embedding_with_plugin(input, config) + elif multiple and isinstance(input[0], str): input = [text.replace("\n", " ") for text in input] + with suppress(NotImplementedError): + return [_get_embedding_with_plugin(i, config) for i in input] + model = config.embedding_model kwargs = {"model": model} kwargs.update(config.get_openai_credentials(model)) @@ -62,3 +70,13 @@ def get_embedding( embeddings = sorted(embeddings, key=lambda x: x["index"]) return [d["embedding"] for d in embeddings] + + +def _get_embedding_with_plugin(text: str, config: Config) -> Embedding: + for plugin in config.plugins: + if plugin.can_handle_text_embedding(text): + embedding = plugin.handle_text_embedding(text) + if embedding is not None: + return embedding + + raise NotImplementedError diff --git a/autogpt/models/base_open_ai_plugin.py b/autogpt/models/base_open_ai_plugin.py index c0aac8ed2e5..60f6f91bf9d 100644 --- a/autogpt/models/base_open_ai_plugin.py +++ b/autogpt/models/base_open_ai_plugin.py @@ -198,18 +198,20 @@ def handle_chat_completion( def can_handle_text_embedding(self, text: str) -> bool: """This method is called to check that the plugin can handle the text_embedding method. + Args: text (str): The text to be convert to embedding. - Returns: - bool: True if the plugin can handle the text_embedding method.""" + Returns: + bool: True if the plugin can handle the text_embedding method.""" return False - def handle_text_embedding(self, text: str) -> list: - """This method is called when the chat completion is done. + def handle_text_embedding(self, text: str) -> list[float]: + """This method is called to create a text embedding. + Args: text (str): The text to be convert to embedding. Returns: - list: The text embedding. + list[float]: The created embedding vector. """ def can_handle_user_input(self, user_input: str) -> bool: diff --git a/tests/unit/models/test_base_open_api_plugin.py b/tests/unit/models/test_base_open_api_plugin.py index 4d41eddd377..e656f464350 100644 --- a/tests/unit/models/test_base_open_api_plugin.py +++ b/tests/unit/models/test_base_open_api_plugin.py @@ -54,6 +54,7 @@ def test_dummy_plugin_default_methods(dummy_plugin): assert not dummy_plugin.can_handle_pre_command() assert not dummy_plugin.can_handle_post_command() assert not dummy_plugin.can_handle_chat_completion(None, None, None, None) + assert not dummy_plugin.can_handle_text_embedding(None) assert dummy_plugin.on_response("hello") == "hello" assert dummy_plugin.post_prompt(None) is None @@ -77,3 +78,4 @@ def test_dummy_plugin_default_methods(dummy_plugin): assert isinstance(post_command, str) assert post_command == "upgraded successfully!" assert dummy_plugin.handle_chat_completion(None, None, None, None) is None + assert dummy_plugin.handle_text_embedding(None) is None