diff --git a/jigsawstack/__init__.py b/jigsawstack/__init__.py index 93ef980..7931505 100644 --- a/jigsawstack/__init__.py +++ b/jigsawstack/__init__.py @@ -13,6 +13,7 @@ from .summary import Summary, AsyncSummary from .geo import Geo, AsyncGeo from .prompt_engine import PromptEngine, AsyncPromptEngine +from .embedding import Embedding, AsyncEmbedding from .exceptions import JigsawStackError @@ -110,6 +111,11 @@ def __init__( api_url=api_url, disable_request_logging=disable_request_logging, ) + self.embedding = Embedding( + api_key=api_key, + api_url=api_url, + disable_request_logging=disable_request_logging, + ).execute class AsyncJigsawStack: @@ -215,6 +221,11 @@ def __init__( api_url=api_url, disable_request_logging=disable_request_logging, ) + self.embedding = AsyncEmbedding( + api_key=api_key, + api_url=api_url, + disable_request_logging=disable_request_logging, + ).execute # Create a global instance of the Web class diff --git a/jigsawstack/embedding.py b/jigsawstack/embedding.py new file mode 100644 index 0000000..2628c97 --- /dev/null +++ b/jigsawstack/embedding.py @@ -0,0 +1,77 @@ +from typing import Any, Dict, List, Union, cast, Literal +from typing_extensions import NotRequired, TypedDict +from .request import Request, RequestConfig +from .async_request import AsyncRequest +from typing import List, Union +from ._config import ClientConfig + + +class EmbeddingParams(TypedDict): + text: NotRequired[str] + file_content: NotRequired[Any] + type: Literal["text", "text-other", "image", "audio", "pdf"] + url: NotRequired[str] + file_store_key: NotRequired[str] + token_overflow_mode: NotRequired[Literal["truncate", "chunk", "error"]] = "chunk" + + +class EmbeddingResponse(TypedDict): + success: bool + embeddings: List[List[float]] + chunks: List[str] + + +class Embedding(ClientConfig): + + config: RequestConfig + + def __init__( + self, + api_key: str, + api_url: str, + disable_request_logging: Union[bool, None] = False, + ): + super().__init__(api_key, api_url, disable_request_logging) + self.config = RequestConfig( + api_url=api_url, + api_key=api_key, + disable_request_logging=disable_request_logging, + ) + + def execute(self, params: EmbeddingParams) -> EmbeddingResponse: + path = "/embedding" + resp = Request( + config=self.config, + path=path, + params=cast(Dict[Any, Any], params), + verb="post", + ).perform_with_content() + return resp + + +class AsyncEmbedding(ClientConfig): + + config: RequestConfig + + def __init__( + self, + api_key: str, + api_url: str, + disable_request_logging: Union[bool, None] = False, + ): + super().__init__(api_key, api_url, disable_request_logging) + self.config = RequestConfig( + api_url=api_url, + api_key=api_key, + disable_request_logging=disable_request_logging, + ) + + async def execute(self, params: EmbeddingParams) -> EmbeddingResponse: + path = "/embedding" + resp = await AsyncRequest( + config=self.config, + path=path, + params=cast(Dict[Any, Any], params), + verb="post", + ).perform_with_content() + return resp diff --git a/setup.py b/setup.py index c4def33..50189cc 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="jigsawstack", - version="0.1.24", + version="0.1.25", description="JigsawStack Python SDK", long_description=open("README.md", encoding="utf8").read(), long_description_content_type="text/markdown", diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py new file mode 100644 index 0000000..bf2e1e6 --- /dev/null +++ b/tests/test_embedding_async.py @@ -0,0 +1,23 @@ +from unittest.mock import MagicMock +import unittest +from jigsawstack.exceptions import JigsawStackError +from jigsawstack import AsyncJigsawStack +import pytest +import asyncio +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_async_embedding_generation_response(): + async def _test(): + client = AsyncJigsawStack() + try: + result = await client.embedding({"text": "Hello, World!", "type": "text"}) + logger.info(result) + assert result["success"] == True + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError: {e}") + + asyncio.run(_test())