From 69599da10b0a69a9caa53d037f4472c422685cbd Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 29 Apr 2023 09:09:49 +0800 Subject: [PATCH 1/4] add feature custom text embedding in plugin --- autogpt/llm/llm_utils.py | 6 ++++- autogpt/models/base_open_ai_plugin.py | 22 +++++++++++++++++++ .../unit/models/test_base_open_api_plugin.py | 2 ++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/autogpt/llm/llm_utils.py b/autogpt/llm/llm_utils.py index 9a2400c79478..a4cd9bc24c43 100644 --- a/autogpt/llm/llm_utils.py +++ b/autogpt/llm/llm_utils.py @@ -222,7 +222,11 @@ def get_ada_embedding(text: str) -> List[float]: cfg = Config() model = "text-embedding-ada-002" text = text.replace("\n", " ") - + for plugin in cfg.plugins: + if plugin.can_handle_text_embedding(text): + embedding = plugin.handle_text_embedding(text) + if embedding is not None: + return embedding if cfg.use_azure: kwargs = {"engine": cfg.get_azure_deployment_id_for_model(model)} else: diff --git a/autogpt/models/base_open_ai_plugin.py b/autogpt/models/base_open_ai_plugin.py index 046295c0dcbe..f9c94ebe76a4 100644 --- a/autogpt/models/base_open_ai_plugin.py +++ b/autogpt/models/base_open_ai_plugin.py @@ -197,3 +197,25 @@ def handle_chat_completion( str: The resulting response. """ pass + + def can_handle_text_embedding( + self, text: str + ) -> bool: + """This method is called to check that the plugin can + handle the text_embedding method. + Args: + text (str): The text to be convert to embedding. + Returns: + bool: True if the plugin can handle the text_embedding method.""" + return False + + def handle_text_embedding( + self, text: str + ) -> list: + """This method is called when the chat completion is done. + Args: + text (str): The text to be convert to embedding. + Returns: + list: The text embedding. + """ + pass diff --git a/tests/unit/models/test_base_open_api_plugin.py b/tests/unit/models/test_base_open_api_plugin.py index 456c74c762af..32ad15574bc7 100644 --- a/tests/unit/models/test_base_open_api_plugin.py +++ b/tests/unit/models/test_base_open_api_plugin.py @@ -62,6 +62,7 @@ def test_dummy_plugin_default_methods(dummy_plugin): assert not dummy_plugin.can_handle_pre_command() assert not dummy_plugin.can_handle_post_command() assert not dummy_plugin.can_handle_chat_completion(None, None, None, None) + assert not dummy_plugin.can_handle_text_embedding(None) assert dummy_plugin.on_response("hello") == "hello" assert dummy_plugin.post_prompt(None) is None @@ -85,3 +86,4 @@ def test_dummy_plugin_default_methods(dummy_plugin): assert isinstance(post_command, str) assert post_command == "upgraded successfully!" assert dummy_plugin.handle_chat_completion(None, None, None, None) is None + assert dummy_plugin.handle_text_embedding(None) is None From 86bd9ff7c544e8d0cd665575dc8056ac210e29df Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 29 Apr 2023 09:26:49 +0800 Subject: [PATCH 2/4] black code format --- autogpt/models/base_open_ai_plugin.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/autogpt/models/base_open_ai_plugin.py b/autogpt/models/base_open_ai_plugin.py index f9c94ebe76a4..811ecbdf3c4e 100644 --- a/autogpt/models/base_open_ai_plugin.py +++ b/autogpt/models/base_open_ai_plugin.py @@ -198,9 +198,7 @@ def handle_chat_completion( """ pass - def can_handle_text_embedding( - self, text: str - ) -> bool: + def can_handle_text_embedding(self, text: str) -> bool: """This method is called to check that the plugin can handle the text_embedding method. 
Args: @@ -209,9 +207,7 @@ def can_handle_text_embedding( bool: True if the plugin can handle the text_embedding method.""" return False - def handle_text_embedding( - self, text: str - ) -> list: + def handle_text_embedding(self, text: str) -> list: """This method is called when the chat completion is done. Args: text (str): The text to be convert to embedding. From ff51c5a5a67095f717376faa39065755a529964f Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 14 Jul 2023 22:39:50 +0200 Subject: [PATCH 3/4] _get_embedding_with_plugin() --- autogpt/memory/vector/utils.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/autogpt/memory/vector/utils.py b/autogpt/memory/vector/utils.py index eb69125666aa..1b050d562596 100644 --- a/autogpt/memory/vector/utils.py +++ b/autogpt/memory/vector/utils.py @@ -1,3 +1,4 @@ +from contextlib import suppress from typing import Any, overload import numpy as np @@ -12,12 +13,12 @@ @overload -def get_embedding(input: str | TText) -> Embedding: +def get_embedding(input: str | TText, config: Config) -> Embedding: ... @overload -def get_embedding(input: list[str] | list[TText]) -> list[Embedding]: +def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]: ... @@ -37,9 +38,16 @@ def get_embedding( if isinstance(input, str): input = input.replace("\n", " ") + + with suppress(NotImplementedError): + return _get_embedding_with_plugin(input, config) + elif multiple and isinstance(input[0], str): input = [text.replace("\n", " ") for text in input] + with suppress(NotImplementedError): + return [_get_embedding_with_plugin(i, config) for i in input] + model = config.embedding_model kwargs = {"model": model} kwargs.update(config.get_openai_credentials(model)) @@ -62,3 +70,13 @@ def get_embedding( embeddings = sorted(embeddings, key=lambda x: x["index"]) return [d["embedding"] for d in embeddings] + + +def _get_embedding_with_plugin(text: str, config: Config) -> Embedding: + for plugin in config.plugins: + if plugin.can_handle_text_embedding(text): + embedding = plugin.handle_text_embedding(text) + if embedding is not None: + return embedding + + raise NotImplementedError From 818f381c8edd2492743ad3bcadb6d9e58d1714d1 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 14 Jul 2023 22:46:11 +0200 Subject: [PATCH 4/4] Fix docstring & type hint --- autogpt/models/base_open_ai_plugin.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/autogpt/models/base_open_ai_plugin.py b/autogpt/models/base_open_ai_plugin.py index c0aac8ed2e57..60f6f91bf9dd 100644 --- a/autogpt/models/base_open_ai_plugin.py +++ b/autogpt/models/base_open_ai_plugin.py @@ -198,18 +198,20 @@ def handle_chat_completion( def can_handle_text_embedding(self, text: str) -> bool: """This method is called to check that the plugin can handle the text_embedding method. + Args: text (str): The text to be convert to embedding. - Returns: - bool: True if the plugin can handle the text_embedding method.""" + Returns: + bool: True if the plugin can handle the text_embedding method.""" return False - def handle_text_embedding(self, text: str) -> list: - """This method is called when the chat completion is done. + def handle_text_embedding(self, text: str) -> list[float]: + """This method is called to create a text embedding. + Args: text (str): The text to be convert to embedding. Returns: - list: The text embedding. + list[float]: The created embedding vector. 
""" def can_handle_user_input(self, user_input: str) -> bool: