From b84878ad55c54b54054e6f21adf588e3138cd43e Mon Sep 17 00:00:00 2001
From: Narcisse
Date: Sun, 5 Jan 2025 15:39:36 +0100
Subject: [PATCH 1/3] clean up

---
Review notes (this section is not part of the commit message):
* Adds jigsawstack/embedding_api.py. The field docstrings in this revision
  still talk about summaries and translation; [PATCH 2/3] removes them.
* Embedding (sync) has no execute() in this revision; [PATCH 2/3] adds it.

 jigsawstack/embedding_api.py | 80 ++++++++++++++++++++++++++++++++++++
 1 file changed, 80 insertions(+)
 create mode 100644 jigsawstack/embedding_api.py

diff --git a/jigsawstack/embedding_api.py b/jigsawstack/embedding_api.py
new file mode 100644
index 0000000..faacaee
--- /dev/null
+++ b/jigsawstack/embedding_api.py
@@ -0,0 +1,80 @@
+from typing import Any, Dict, List, Union, cast, Literal
+from typing_extensions import NotRequired, TypedDict
+from .request import Request, RequestConfig
+from .async_request import AsyncRequest
+from typing import List, Union
+from ._config import ClientConfig
+
+
+class EmbeddingParams(TypedDict):
+    text: Union[str, List[str]]
+    """
+    The text to summarize.
+    """
+
+    type: NotRequired[Literal["text", "points"]]
+
+    """
+    The summary result type. Supported values are: text, points
+    """
+    url: NotRequired[str]
+    file_store_key: NotRequired[str]
+    max_points: NotRequired[int]
+    max_characters: NotRequired[int]
+
+
+class EmbeddingResponse(TypedDict):
+    success: bool
+    """
+    Indicates whether the translation was successful.
+    """
+    summary: str
+    """
+    The summarized text.
+    """
+
+
+class Embedding(ClientConfig):
+
+    config: RequestConfig
+
+    def __init__(
+        self,
+        api_key: str,
+        api_url: str,
+        disable_request_logging: Union[bool, None] = False,
+    ):
+        super().__init__(api_key, api_url, disable_request_logging)
+        self.config = RequestConfig(
+            api_url=api_url,
+            api_key=api_key,
+            disable_request_logging=disable_request_logging,
+        )
+
+
+class AsyncEmbedding(ClientConfig):
+
+    config: RequestConfig
+
+    def __init__(
+        self,
+        api_key: str,
+        api_url: str,
+        disable_request_logging: Union[bool, None] = False,
+    ):
+        super().__init__(api_key, api_url, disable_request_logging)
+        self.config = RequestConfig(
+            api_url=api_url,
+            api_key=api_key,
+            disable_request_logging=disable_request_logging,
+        )
+
+    async def execute(self, params: EmbeddingParams) -> EmbeddingResponse:
+        path = "/ai/embedding"
+        resp = await AsyncRequest(
+            config=self.config,
+            path=path,
+            params=cast(Dict[Any, Any], params),
+            verb="post",
+        ).perform_with_content()
+        return resp

From b2a110d2c4bcd3e455db40243414ae9b2193f9d2 Mon Sep 17 00:00:00 2001
From: Narcisse
Date: Mon, 6 Jan 2025 12:24:25 +0100
Subject: [PATCH 2/3] update embedding

---
Review notes (this section is not part of the commit message):
* EmbeddingParams.token_overflow_mode no longer assigns `= "chunk"`. A
  TypedDict class body may only contain annotations: the assigned value was
  never applied to the plain dicts callers pass in, and type checkers
  reject it.
* Both execute() methods gained a one-line docstring.
* Hunk line counts and the diffstat below account for the added lines. The
  `index` blob ids and the similarity percentage are carried over from the
  original series and will differ once the patch is regenerated.

 jigsawstack/__init__.py                       | 11 ++++++
 .../{embedding_api.py => embedding.py}        | 41 ++++++++++---------
 setup.py                                      |  2 +-
 3 files changed, 33 insertions(+), 21 deletions(-)
 rename jigsawstack/{embedding_api.py => embedding.py} (72%)

diff --git a/jigsawstack/__init__.py b/jigsawstack/__init__.py
index 93ef980..7931505 100644
--- a/jigsawstack/__init__.py
+++ b/jigsawstack/__init__.py
@@ -13,6 +13,7 @@
 from .summary import Summary, AsyncSummary
 from .geo import Geo, AsyncGeo
 from .prompt_engine import PromptEngine, AsyncPromptEngine
+from .embedding import Embedding, AsyncEmbedding
 from .exceptions import JigsawStackError
 
 
@@ -110,6 +111,11 @@ def __init__(
             api_url=api_url,
             disable_request_logging=disable_request_logging,
         )
+        self.embedding = Embedding(
+            api_key=api_key,
+            api_url=api_url,
+            disable_request_logging=disable_request_logging,
+        ).execute
 
 
 class AsyncJigsawStack:
@@ -215,6 +221,11 @@ def __init__(
             api_url=api_url,
             disable_request_logging=disable_request_logging,
         )
+        self.embedding = AsyncEmbedding(
+            api_key=api_key,
+            api_url=api_url,
+            disable_request_logging=disable_request_logging,
+        ).execute
 
 
 # Create a global instance of the Web class
diff --git a/jigsawstack/embedding_api.py b/jigsawstack/embedding.py
similarity index 72%
rename from jigsawstack/embedding_api.py
rename to jigsawstack/embedding.py
index faacaee..2628c97 100644
--- a/jigsawstack/embedding_api.py
+++ b/jigsawstack/embedding.py
@@ -7,31 +7,20 @@
 
 
 class EmbeddingParams(TypedDict):
-    text: Union[str, List[str]]
-    """
-    The text to summarize.
-    """
-
-    type: NotRequired[Literal["text", "points"]]
-
-    """
-    The summary result type. Supported values are: text, points
-    """
+    text: NotRequired[str]
+    file_content: NotRequired[Any]
+    type: Literal["text", "text-other", "image", "audio", "pdf"]
     url: NotRequired[str]
     file_store_key: NotRequired[str]
-    max_points: NotRequired[int]
-    max_characters: NotRequired[int]
+    # TypedDict keys cannot declare defaults; when omitted, the API presumably
+    # falls back to "chunk" (TODO confirm against the API reference).
+    token_overflow_mode: NotRequired[Literal["truncate", "chunk", "error"]]
 
 
 class EmbeddingResponse(TypedDict):
     success: bool
-    """
-    Indicates whether the translation was successful.
-    """
-    summary: str
-    """
-    The summarized text.
-    """
+    embeddings: List[List[float]]
+    chunks: List[str]
 
 
 class Embedding(ClientConfig):
@@ -51,6 +40,17 @@ def __init__(
             disable_request_logging=disable_request_logging,
         )
 
+    def execute(self, params: EmbeddingParams) -> EmbeddingResponse:
+        """POST ``params`` to ``/embedding`` and return the request's result."""
+        path = "/embedding"
+        resp = Request(
+            config=self.config,
+            path=path,
+            params=cast(Dict[Any, Any], params),
+            verb="post",
+        ).perform_with_content()
+        return resp
+
 
 class AsyncEmbedding(ClientConfig):
 
@@ -70,7 +70,8 @@ def __init__(
         )
 
     async def execute(self, params: EmbeddingParams) -> EmbeddingResponse:
-        path = "/ai/embedding"
+        """POST ``params`` to ``/embedding`` and return the request's result."""
+        path = "/embedding"
         resp = await AsyncRequest(
             config=self.config,
             path=path,
diff --git a/setup.py b/setup.py
index c4def33..50189cc 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 
 setup(
     name="jigsawstack",
-    version="0.1.24",
+    version="0.1.25",
     description="JigsawStack Python SDK",
     long_description=open("README.md", encoding="utf8").read(),
     long_description_content_type="text/markdown",

From 5d58d0e2aa3658703d6afe005299f6717607b7c7 Mon Sep 17 00:00:00 2001
From: Narcisse
Date: Mon, 6 Jan 2025 12:29:01 +0100
Subject: [PATCH 3/3] clean up

---
Review notes (this section is not part of the commit message):
* The success assertion uses `is True` instead of `== True`. The `index`
  blob id below is carried over from the original series and will differ
  once the patch is regenerated.
* The test constructs AsyncJigsawStack() and awaits client.embedding(...)
  without any mocking, so it presumably needs network access and API
  credentials -- confirm how CI provides them.

 tests/test_embedding_async.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 tests/test_embedding_async.py

diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py
new file mode 100644
index 0000000..bf2e1e6
--- /dev/null
+++ b/tests/test_embedding_async.py
@@ -0,0 +1,23 @@
+from unittest.mock import MagicMock
+import unittest
+from jigsawstack.exceptions import JigsawStackError
+from jigsawstack import AsyncJigsawStack
+import pytest
+import asyncio
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def test_async_embedding_generation_response():
+    async def _test():
+        client = AsyncJigsawStack()
+        try:
+            result = await client.embedding({"text": "Hello, World!", "type": "text"})
+            logger.info(result)
+            assert result["success"] is True
+        except JigsawStackError as e:
+            pytest.fail(f"Unexpected JigsawStackError: {e}")
+
+    asyncio.run(_test())