Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate plugin.handle_text_embedding hook #2804

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 20 additions & 2 deletions autogpt/memory/vector/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import suppress
from typing import Any, overload

import numpy as np
Expand All @@ -12,12 +13,12 @@


@overload
def get_embedding(input: str | TText) -> Embedding:
def get_embedding(input: str | TText, config: Config) -> Embedding:
...


@overload
def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]:
...


Expand All @@ -37,9 +38,16 @@

if isinstance(input, str):
input = input.replace("\n", " ")

with suppress(NotImplementedError):
return _get_embedding_with_plugin(input, config)

elif multiple and isinstance(input[0], str):
input = [text.replace("\n", " ") for text in input]

with suppress(NotImplementedError):
return [_get_embedding_with_plugin(i, config) for i in input]

model = config.embedding_model
kwargs = {"model": model}
kwargs.update(config.get_openai_credentials(model))
Expand All @@ -62,3 +70,13 @@

embeddings = sorted(embeddings, key=lambda x: x["index"])
return [d["embedding"] for d in embeddings]


def _get_embedding_with_plugin(text: str, config: Config) -> Embedding:
for plugin in config.plugins:
if plugin.can_handle_text_embedding(text):
embedding = plugin.handle_text_embedding(text)

Check warning on line 78 in autogpt/memory/vector/utils.py

View check run for this annotation

Codecov / codecov/patch

autogpt/memory/vector/utils.py#L78

Added line #L78 was not covered by tests
if embedding is not None:
return embedding

Check warning on line 80 in autogpt/memory/vector/utils.py

View check run for this annotation

Codecov / codecov/patch

autogpt/memory/vector/utils.py#L80

Added line #L80 was not covered by tests

raise NotImplementedError
12 changes: 7 additions & 5 deletions autogpt/models/base_open_ai_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,20 @@ def handle_chat_completion(
def can_handle_text_embedding(self, text: str) -> bool:
"""This method is called to check that the plugin can
handle the text_embedding method.

Args:
text (str): The text to be convert to embedding.
Returns:
bool: True if the plugin can handle the text_embedding method."""
Returns:
bool: True if the plugin can handle the text_embedding method."""
return False

def handle_text_embedding(self, text: str) -> list:
"""This method is called when the chat completion is done.
def handle_text_embedding(self, text: str) -> list[float]:
"""This method is called to create a text embedding.

Args:
text (str): The text to be convert to embedding.
Returns:
list: The text embedding.
list[float]: The created embedding vector.
"""

def can_handle_user_input(self, user_input: str) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/models/test_base_open_api_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_dummy_plugin_default_methods(dummy_plugin):
assert not dummy_plugin.can_handle_pre_command()
assert not dummy_plugin.can_handle_post_command()
assert not dummy_plugin.can_handle_chat_completion(None, None, None, None)
assert not dummy_plugin.can_handle_text_embedding(None)

assert dummy_plugin.on_response("hello") == "hello"
assert dummy_plugin.post_prompt(None) is None
Expand All @@ -77,3 +78,4 @@ def test_dummy_plugin_default_methods(dummy_plugin):
assert isinstance(post_command, str)
assert post_command == "upgraded successfully!"
assert dummy_plugin.handle_chat_completion(None, None, None, None) is None
assert dummy_plugin.handle_text_embedding(None) is None