diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b1f5b26..1eea9f3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,6 +48,7 @@ jobs: - test_web.py - test_deep_research.py - test_ai_scrape.py + - test_vocr.py steps: - uses: actions/checkout@v4 diff --git a/jigsawstack/__init__.py b/jigsawstack/__init__.py index 091f775..9218810 100644 --- a/jigsawstack/__init__.py +++ b/jigsawstack/__init__.py @@ -21,24 +21,29 @@ class JigsawStack: + api_key: str + base_url: str + headers: Dict[str, str] audio: Audio - vision: Vision - image_generation: ImageGeneration - file: Store - web: Web - search: Search classification: Classification + embedding: Embedding + embedding_v2: EmbeddingV2 + store: Store + image_generation: ImageGeneration + prediction: Prediction prompt_engine: PromptEngine - api_key: str - api_url: str - headers: Dict[str, str] - # disable_request_logging: bool + sentiment: Sentiment + summary: Summary + text_to_sql: SQL + translate: Translate + validate: Validate + vision: Vision + web: Web def __init__( self, api_key: Union[str, None] = None, - api_url: Union[str, None] = None, - # disable_request_logging: Union[bool, None] = None, + base_url: Union[str, None] = None, headers: Union[Dict[str, str], None] = None, ) -> None: if api_key is None: @@ -49,117 +54,89 @@ def __init__( "The api_key client option must be set either by passing api_key to the client or by setting the JIGSAWSTACK_API_KEY environment variable" ) - if api_url is None: - api_url = os.environ.get("JIGSAWSTACK_API_URL") - if api_url is None: - api_url = "https://api.jigsawstack.com/" + if base_url is None: + base_url = os.environ.get("JIGSAWSTACK_API_URL") + if base_url is None: + base_url = "https://api.jigsawstack.com/" self.api_key = api_key - self.api_url = api_url + self.base_url = base_url - self.headers = headers or {} + self.headers = headers or {"Content-Type": "application/json"} - disable_request_logging = self.headers.get("x-jigsaw-no-request-log") + self.audio = Audio(api_key=api_key, base_url=base_url + "/v1", headers=headers) + + self.web = Web(api_key=api_key, base_url=base_url + "/v1", headers=headers) - self.audio = Audio( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) - self.web = Web( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) self.sentiment = Sentiment( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).analyze - self.validate = Validate( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + + self.validate = Validate(api_key=api_key, base_url=base_url + "/v1", headers=headers) self.summary = Summary( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).summarize - self.vision = Vision( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + + self.vision = Vision(api_key=api_key, base_url=base_url + "/v1", headers=headers) + self.prediction = Prediction( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).predict + self.text_to_sql = SQL( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).text_to_sql - self.store = Store( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) - self.translate = Translate( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + + self.store = Store(api_key=api_key, base_url=base_url + "/v1", headers=headers) + + self.translate = Translate(api_key=api_key, base_url=base_url + "/v1", headers=headers) self.embedding = Embedding( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).execute - self.embeddingV2 = EmbeddingV2( - api_key=api_key, - api_url=api_url + "/v2", - disable_request_logging=disable_request_logging, + self.embedding_v2 = EmbeddingV2( + api_key=api_key, base_url=base_url + "/v2", headers=headers ).execute self.image_generation = ImageGeneration( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).image_generation self.classification = Classification( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).classify self.prompt_engine = PromptEngine( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ) class AsyncJigsawStack: - validate: AsyncValidate - web: AsyncWeb + api_key: str + base_url: str + headers: Dict[str, str] audio: AsyncAudio - vision: AsyncVision + classification: AsyncClassification + embedding: AsyncEmbedding + embedding_v2: AsyncEmbeddingV2 image_generation: AsyncImageGeneration - store: AsyncStore + prediction: AsyncPrediction prompt_engine: AsyncPromptEngine - api_key: str - api_url: str - disable_request_logging: bool + sentiment: AsyncSentiment + store: AsyncStore + summary: AsyncSummary + text_to_sql: AsyncSQL + translate: AsyncTranslate + validate: AsyncValidate + vision: AsyncVision + web: AsyncWeb def __init__( self, api_key: Union[str, None] = None, - api_url: Union[str, None] = None, - disable_request_logging: Union[bool, None] = None, + base_url: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, ) -> None: if api_key is None: api_key = os.environ.get("JIGSAWSTACK_API_KEY") @@ -169,100 +146,59 @@ def __init__( "The api_key client option must be set either by passing api_key to the client or by setting the JIGSAWSTACK_API_KEY environment variable" ) - if api_url is None: - api_url = os.environ.get("JIGSAWSTACK_API_URL") - if api_url is None: - api_url = "https://api.jigsawstack.com/" + if base_url is None: + base_url = os.environ.get("JIGSAWSTACK_API_URL") + if base_url is None: + base_url = "https://api.jigsawstack.com/" self.api_key = api_key - self.api_url = api_url + self.base_url = base_url + self.headers = headers or {"Content-Type": "application/json"} - self.web = AsyncWeb( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + self.web = AsyncWeb(api_key=api_key, base_url=base_url + "/v1", headers=headers) - self.validate = AsyncValidate( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) - self.audio = AsyncAudio( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + self.validate = AsyncValidate(api_key=api_key, base_url=base_url + "/v1", headers=headers) - self.vision = AsyncVision( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + self.audio = AsyncAudio(api_key=api_key, base_url=base_url + "/v1", headers=headers) - self.store = AsyncStore( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + self.vision = AsyncVision(api_key=api_key, base_url=base_url + "/v1", headers=headers) + + self.store = AsyncStore(api_key=api_key, base_url=base_url + "/v1", headers=headers) self.summary = AsyncSummary( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).summarize - self.prediction = AsyncPrediction( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ).predict + self.prediction = AsyncPrediction(api_key=api_key, base_url=base_url + "/v1").predict + self.text_to_sql = AsyncSQL( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).text_to_sql self.sentiment = AsyncSentiment( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).analyze - self.translate = AsyncTranslate( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, - ) + self.translate = AsyncTranslate(api_key=api_key, base_url=base_url + "/v1", headers=headers) self.embedding = AsyncEmbedding( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).execute - self.embeddingV2 = AsyncEmbeddingV2( - api_key=api_key, - api_url=api_url + "/v2", - disable_request_logging=disable_request_logging, + self.embedding_v2 = AsyncEmbeddingV2( + api_key=api_key, base_url=base_url + "/v2", headers=headers ).execute self.image_generation = AsyncImageGeneration( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).image_generation self.classification = AsyncClassification( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ).classify self.prompt_engine = AsyncPromptEngine( - api_key=api_key, - api_url=api_url + "/v1", - disable_request_logging=disable_request_logging, + api_key=api_key, base_url=base_url + "/v1", headers=headers ) diff --git a/jigsawstack/_config.py b/jigsawstack/_config.py index 6e15b54..3a007d8 100644 --- a/jigsawstack/_config.py +++ b/jigsawstack/_config.py @@ -1,17 +1,17 @@ -from typing import Union +from typing import Dict, Union class ClientConfig: base_url: str api_key: str - disable_request_logging: Union[bool, None] = None + headers: Union[Dict[str, str], None] def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = None, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): self.api_key = api_key - self.api_url = api_url - self.disable_request_logging = disable_request_logging + self.base_url = base_url + self.headers = headers diff --git a/jigsawstack/async_request.py b/jigsawstack/async_request.py index 26a7e53..d8f530d 100644 --- a/jigsawstack/async_request.py +++ b/jigsawstack/async_request.py @@ -13,9 +13,9 @@ class AsyncRequestConfig(TypedDict): - api_url: str + base_url: str api_key: str - disable_request_logging: Union[bool, None] = False + headers: Union[Dict[str, str], None] class AsyncRequest(Generic[T]): @@ -25,19 +25,19 @@ def __init__( path: str, params: Union[Dict[Any, Any], List[Dict[Any, Any]]], verb: RequestVerb, - headers: Dict[str, str] = None, data: Union[bytes, None] = None, stream: Union[bool, None] = False, + files: Union[Dict[str, Any], None] = None, # Add files parameter ): self.path = path self.params = params self.verb = verb - self.api_url = config.get("api_url") + self.base_url = config.get("base_url") self.api_key = config.get("api_key") self.data = data - self.headers = headers or {"Content-Type": "application/json"} - self.disable_request_logging = config.get("disable_request_logging") + self.headers = config.get("headers", None) or {"Content-Type": "application/json"} self.stream = stream + self.files = files # Store files for multipart requests def __convert_params( self, params: Union[Dict[Any, Any], List[Dict[Any, Any]]] @@ -67,7 +67,7 @@ async def perform(self) -> Union[T, None]: Async method to make an HTTP request to the JigsawStack API. """ async with self.__get_session() as session: - resp = await self.make_request(session, url=f"{self.api_url}{self.path}") + resp = await self.make_request(session, url=f"{self.base_url}{self.path}") # For binary responses if resp.status == 200: @@ -108,7 +108,7 @@ async def perform(self) -> Union[T, None]: async def perform_file(self) -> Union[T, None]: async with self.__get_session() as session: - resp = await self.make_request(session, url=f"{self.api_url}{self.path}") + resp = await self.make_request(session, url=f"{self.base_url}{self.path}") if resp.status != 200: try: @@ -171,15 +171,20 @@ def __get_headers(self) -> Dict[str, str]: Dict[str, str]: Configured HTTP Headers """ h = { - "Content-Type": "application/json", "Accept": "application/json", "x-api-key": f"{self.api_key}", } - if self.disable_request_logging: - h["x-jigsaw-no-request-log"] = "true" + # only add Content-Type if not using multipart (files) + if not self.files and not self.data: + h["Content-Type"] = "application/json" _headers = h.copy() + + # don't override Content-Type if using multipart + if self.files and "Content-Type" in self.headers: + self.headers.pop("Content-Type") + _headers.update(self.headers) return _headers @@ -192,7 +197,7 @@ async def perform_streaming(self) -> AsyncGenerator[Union[T, str], None]: AsyncGenerator[Union[T, str], None]: A generator of response chunks """ async with self.__get_session() as session: - resp = await self.make_request(session, url=f"{self.api_url}{self.path}") + resp = await self.make_request(session, url=f"{self.base_url}{self.path}") # delete calls do not return a body if await resp.text() == "": @@ -231,50 +236,35 @@ async def make_request( self, session: aiohttp.ClientSession, url: str ) -> aiohttp.ClientResponse: headers = self.__get_headers() + params = self.params verb = self.verb - data = self.data + files = self.files - # Convert params to string values for URL encoding - converted_params = self.__convert_params(self.params) + _params = None + _json = None + _data = None + _form_data = None if verb.lower() in ["get", "delete"]: - return await session.request( - verb, - url, - params=converted_params, - headers=headers, - ) - else: - if data is not None: - form_data = aiohttp.FormData() - form_data.add_field( - "file", - BytesIO(data), - content_type=headers.get("Content-Type", "application/octet-stream"), - filename="file", - ) - - if self.params and isinstance(self.params, dict): - form_data.add_field( - "body", json.dumps(self.params), content_type="application/json" - ) - - multipart_headers = headers.copy() - multipart_headers.pop("Content-Type", None) - - return await session.request( - verb, - url, - data=form_data, - headers=multipart_headers, - ) - else: - return await session.request( - verb, - url, - json=self.params, # Keep JSON body as original - headers=headers, - ) + _params = self.__convert_params(params) + elif files: + _form_data = aiohttp.FormData() + _form_data.add_field("file", BytesIO(files["file"]), filename="upload") + if params and isinstance(params, dict): + _form_data.add_field("body", json.dumps(params), content_type="application/json") + + headers.pop("Content-Type", None) + else: # pure JSON request + _json = params + + return await session.request( + verb, + url, + params=_params, + json=_json, + data=_form_data or _data, + headers=headers, + ) def __get_session(self) -> aiohttp.ClientSession: """ diff --git a/jigsawstack/audio.py b/jigsawstack/audio.py index cadfd25..575b839 100644 --- a/jigsawstack/audio.py +++ b/jigsawstack/audio.py @@ -54,15 +54,11 @@ class Audio(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) - self.config = RequestConfig( - api_url=api_url, - api_key=api_key, - disable_request_logging=disable_request_logging, - ) + super().__init__(api_key, base_url, headers) + self.config = RequestConfig(base_url=base_url, api_key=api_key, headers=headers) @overload def speech_to_text( @@ -80,11 +76,8 @@ def speech_to_text( ) -> Union[SpeechToTextResponse, SpeechToTextWebhookResponse]: options = options or {} path = "/ai/transcribe" - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - if isinstance( - blob, dict - ): # If params is provided as a dict, we assume it's the first argument + if isinstance(blob, dict): + # URL or file_store_key based request resp = Request( config=self.config, path=path, @@ -93,13 +86,13 @@ def speech_to_text( ).perform_with_content() return resp + files = {"file": blob} resp = Request( config=self.config, path=path, params=options, - data=blob, - headers=headers, verb="post", + files=files, ).perform_with_content() return resp @@ -110,14 +103,14 @@ class AsyncAudio(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = AsyncRequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -136,8 +129,6 @@ async def speech_to_text( ) -> Union[SpeechToTextResponse, SpeechToTextWebhookResponse]: options = options or {} path = "/ai/transcribe" - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} if isinstance(blob, dict): resp = await AsyncRequest( config=self.config, @@ -147,12 +138,12 @@ async def speech_to_text( ).perform_with_content() return resp + files = {"file": blob} resp = await AsyncRequest( config=self.config, path=path, params=options, - data=blob, - headers=headers, verb="post", + files=files, ).perform_with_content() return resp diff --git a/jigsawstack/classification.py b/jigsawstack/classification.py index 45407e9..134307c 100644 --- a/jigsawstack/classification.py +++ b/jigsawstack/classification.py @@ -67,14 +67,14 @@ class Classification(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def classify(self, params: ClassificationParams) -> ClassificationResponse: @@ -94,14 +94,14 @@ class AsyncClassification(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = AsyncRequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def classify(self, params: ClassificationParams) -> ClassificationResponse: diff --git a/jigsawstack/embedding.py b/jigsawstack/embedding.py index cd755f0..d914c4c 100644 --- a/jigsawstack/embedding.py +++ b/jigsawstack/embedding.py @@ -5,7 +5,6 @@ from ._config import ClientConfig from ._types import BaseResponse from .async_request import AsyncRequest -from .helpers import build_path from .request import Request, RequestConfig @@ -34,14 +33,14 @@ class Embedding(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -55,6 +54,7 @@ def execute( options: EmbeddingParams = None, ) -> EmbeddingResponse: path = "/embedding" + options = options or {} if isinstance(blob, dict): resp = Request( config=self.config, @@ -64,17 +64,12 @@ def execute( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path=path, params=options) - content_type = options.get("content_type", "application/octet-stream") - _headers = {"Content-Type": content_type} - + files = {"file": blob} resp = Request( config=self.config, path=path, params=options, - data=blob, - headers=_headers, + files=files, verb="post", ).perform_with_content() return resp @@ -86,14 +81,14 @@ class AsyncEmbedding(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -107,6 +102,7 @@ async def execute( options: EmbeddingParams = None, ) -> EmbeddingResponse: path = "/embedding" + options = options or {} if isinstance(blob, dict): resp = await AsyncRequest( config=self.config, @@ -116,17 +112,12 @@ async def execute( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path=path, params=options) - content_type = options.get("content_type", "application/octet-stream") - _headers = {"Content-Type": content_type} - + files = {"file": blob} resp = await AsyncRequest( config=self.config, path=path, params=options, - data=blob, - headers=_headers, + files=files, verb="post", ).perform_with_content() return resp diff --git a/jigsawstack/embedding_v2.py b/jigsawstack/embedding_v2.py index fe62f69..685cd52 100644 --- a/jigsawstack/embedding_v2.py +++ b/jigsawstack/embedding_v2.py @@ -5,7 +5,6 @@ from ._config import ClientConfig from .async_request import AsyncRequest from .embedding import Chunk -from .helpers import build_path from .request import Request, RequestConfig @@ -32,14 +31,14 @@ class EmbeddingV2(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -53,6 +52,7 @@ def execute( options: EmbeddingV2Params = None, ) -> EmbeddingV2Response: path = "/embedding" + options = options or {} if isinstance(blob, dict): resp = Request( config=self.config, @@ -62,17 +62,12 @@ def execute( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path=path, params=options) - content_type = options.get("content_type", "application/octet-stream") - _headers = {"Content-Type": content_type} - + files = {"file": blob} resp = Request( config=self.config, path=path, params=options, - data=blob, - headers=_headers, + files=files, verb="post", ).perform_with_content() return resp @@ -84,14 +79,14 @@ class AsyncEmbeddingV2(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -107,6 +102,7 @@ async def execute( options: EmbeddingV2Params = None, ) -> EmbeddingV2Response: path = "/embedding" + options = options or {} if isinstance(blob, dict): resp = await AsyncRequest( config=self.config, @@ -116,17 +112,12 @@ async def execute( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path=path, params=options) - content_type = options.get("content_type", "application/octet-stream") - _headers = {"Content-Type": content_type} - + files = {"file": blob} resp = await AsyncRequest( config=self.config, path=path, params=options, - data=blob, - headers=_headers, + files=files, verb="post", ).perform_with_content() return resp diff --git a/jigsawstack/image_generation.py b/jigsawstack/image_generation.py index 9584cf3..08cf81c 100644 --- a/jigsawstack/image_generation.py +++ b/jigsawstack/image_generation.py @@ -89,14 +89,14 @@ class ImageGeneration(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging=disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def image_generation( @@ -118,14 +118,14 @@ class AsyncImageGeneration(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging=disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def image_generation( diff --git a/jigsawstack/prediction.py b/jigsawstack/prediction.py index ec571a4..00bd3cf 100644 --- a/jigsawstack/prediction.py +++ b/jigsawstack/prediction.py @@ -48,14 +48,14 @@ class Prediction(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def predict(self, params: PredictionParams) -> PredictionResponse: @@ -75,14 +75,14 @@ class AsyncPrediction(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def predict(self, params: PredictionParams) -> PredictionResponse: diff --git a/jigsawstack/prompt_engine.py b/jigsawstack/prompt_engine.py index 3af7fa3..c264db9 100644 --- a/jigsawstack/prompt_engine.py +++ b/jigsawstack/prompt_engine.py @@ -97,14 +97,14 @@ class PromptEngine(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def create(self, params: PromptEngineCreateParams) -> PromptEngineCreateResponse: @@ -203,14 +203,14 @@ class AsyncPromptEngine(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def create(self, params: PromptEngineCreateParams) -> PromptEngineCreateResponse: diff --git a/jigsawstack/request.py b/jigsawstack/request.py index c1967a4..38cbf01 100644 --- a/jigsawstack/request.py +++ b/jigsawstack/request.py @@ -12,9 +12,9 @@ class RequestConfig(TypedDict): - api_url: str + base_url: str api_key: str - disable_request_logging: Union[bool, None] = False + headers: Union[Dict[str, str], None] # This class wraps the HTTP request creation logic @@ -25,19 +25,19 @@ def __init__( path: str, params: Union[Dict[Any, Any], List[Dict[Any, Any]]], verb: RequestVerb, - headers: Dict[str, str] = None, data: Union[bytes, None] = None, stream: Union[bool, None] = False, + files: Union[Dict[str, Any], None] = None, ): self.path = path self.params = params self.verb = verb - self.api_url = config.get("api_url") + self.base_url = config.get("base_url") self.api_key = config.get("api_key") self.data = data - self.headers = headers or {"Content-Type": "application/json"} - self.disable_request_logging = config.get("disable_request_logging") + self.headers = config.get("headers", None) or {"Content-Type": "application/json"} self.stream = stream + self.files = files def perform(self) -> Union[T, None]: """Is the main function that makes the HTTP request @@ -50,7 +50,7 @@ def perform(self) -> Union[T, None]: Raises: requests.HTTPError: If the request fails """ - resp = self.make_request(url=f"{self.api_url}{self.path}") + resp = self.make_request(url=f"{self.base_url}{self.path}") # for binary responses if resp.status_code == 200: @@ -83,7 +83,7 @@ def perform(self) -> Union[T, None]: return cast(T, resp) def perform_file(self) -> Union[T, None]: - resp = self.make_request(url=f"{self.api_url}{self.path}") + resp = self.make_request(url=f"{self.base_url}{self.path}") # delete calls do not return a body if resp.text == "" and resp.status_code == 200: @@ -152,15 +152,20 @@ def __get_headers(self) -> Dict[Any, Any]: """ h = { - "Content-Type": "application/json", "Accept": "application/json", "x-api-key": f"{self.api_key}", } - if self.disable_request_logging: - h["x-jigsaw-no-request-log"] = "true" + # Only add Content-Type if not using multipart (files) + if not self.files and not self.data: + h["Content-Type"] = "application/json" _headers = h.copy() + + # Don't override Content-Type if using multipart + if self.files and "Content-Type" in self.headers: + self.headers.pop("Content-Type") + _headers.update(self.headers) return _headers @@ -176,7 +181,7 @@ def perform_streaming(self) -> Generator[Union[T, str], None, None]: Raises: requests.HTTPError: If the request fails """ - resp = self.make_request(url=f"{self.api_url}{self.path}") + resp = self.make_request(url=f"{self.base_url}{self.path}") # delete calls do not return a body if resp.text == "": @@ -242,21 +247,32 @@ def make_request(self, url: str) -> requests.Response: headers = self.__get_headers() params = self.params verb = self.verb - data = self.data + files = self.files _requestParams = None + _json = None + _data = None + _files = None if verb.lower() in ["get", "delete"]: _requestParams = params - + elif files: # multipart request + _files = files + if params and isinstance(params, dict): + _data = {"body": json.dumps(params)} + headers.pop("Content-Type", None) # let requests set it for multipart + + else: # pure JSON request + _json = params try: return requests.request( verb, url, params=_requestParams, - json=params, + json=_json, headers=headers, - data=data, + data=_data, + files=_files, stream=self.stream, ) except requests.HTTPError as e: diff --git a/jigsawstack/search.py b/jigsawstack/search.py index 21b0187..7898f8b 100644 --- a/jigsawstack/search.py +++ b/jigsawstack/search.py @@ -225,14 +225,14 @@ class Search(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def search(self, params: SearchParams) -> SearchResponse: @@ -287,14 +287,14 @@ class AsyncSearch(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = AsyncRequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def search(self, params: SearchParams) -> SearchResponse: diff --git a/jigsawstack/sentiment.py b/jigsawstack/sentiment.py index ef5e9df..7bc1acb 100644 --- a/jigsawstack/sentiment.py +++ b/jigsawstack/sentiment.py @@ -48,14 +48,14 @@ class Sentiment(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def analyze(self, params: SentimentParams) -> SentimentResponse: @@ -75,14 +75,14 @@ class AsyncSentiment(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def analyze(self, params: SentimentParams) -> SentimentResponse: diff --git a/jigsawstack/sql.py b/jigsawstack/sql.py index b895485..4fa0ac8 100644 --- a/jigsawstack/sql.py +++ b/jigsawstack/sql.py @@ -43,14 +43,14 @@ class SQL(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def text_to_sql(self, params: SQLParams) -> SQLResponse: @@ -70,14 +70,14 @@ class AsyncSQL(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def text_to_sql(self, params: SQLParams) -> SQLResponse: diff --git a/jigsawstack/store.py b/jigsawstack/store.py index 0693f49..89facea 100644 --- a/jigsawstack/store.py +++ b/jigsawstack/store.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any, Dict, Union from typing_extensions import NotRequired, TypedDict @@ -32,14 +32,14 @@ class Store(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def upload( @@ -51,14 +51,16 @@ def upload( path = build_path(base_path="/store/file", params=options) content_type = options.get("content_type", "application/octet-stream") - _headers = {"Content-Type": content_type} + config_with_headers = self.config.copy() + if config_with_headers.get("headers") is None: + config_with_headers["headers"] = {} + config_with_headers["headers"]["Content-Type"] = content_type resp = Request( - config=self.config, - params=options, # Empty params since we're using them in the URL + config=config_with_headers, + params={}, path=path, data=file, - headers=_headers, verb="post", ).perform_with_content() return resp @@ -90,14 +92,14 @@ class AsyncStore(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = AsyncRequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def upload( @@ -108,13 +110,17 @@ async def upload( path = build_path(base_path="/store/file", params=options) content_type = options.get("content_type", "application/octet-stream") - _headers = {"Content-Type": content_type} + + config_with_headers = self.config.copy() + if config_with_headers.get("headers") is None: + config_with_headers["headers"] = {} + config_with_headers["headers"]["Content-Type"] = content_type + resp = await AsyncRequest( - config=self.config, - params=options, # Empty params since we're using them in the URL + config=config_with_headers, + params={}, path=path, data=file, - headers=_headers, verb="post", ).perform_with_content() return resp diff --git a/jigsawstack/summary.py b/jigsawstack/summary.py index 0d19b39..48fe578 100644 --- a/jigsawstack/summary.py +++ b/jigsawstack/summary.py @@ -53,14 +53,14 @@ class Summary(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def summarize(self, params: SummaryParams) -> SummaryResponse: @@ -80,14 +80,14 @@ class AsyncSummary(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def summarize(self, params: SummaryParams) -> SummaryResponse: diff --git a/jigsawstack/translate.py b/jigsawstack/translate.py index 63b7fa5..b609540 100644 --- a/jigsawstack/translate.py +++ b/jigsawstack/translate.py @@ -5,7 +5,6 @@ from ._config import ClientConfig from ._types import BaseResponse from .async_request import AsyncRequest -from .helpers import build_path from .request import Request, RequestConfig @@ -64,14 +63,14 @@ class Translate(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def text(self, params: TranslateParams) -> TranslateResponse: @@ -95,6 +94,8 @@ def image( blob: Union[TranslateImageParams, bytes], options: TranslateImageParams = None, ) -> Union[TranslateImageResponse, bytes]: + path = "/ai/translate/image" + options = options or {} if isinstance( blob, dict ): # If params is provided as a dict, we assume it's the first argument @@ -106,17 +107,12 @@ def image( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/ai/translate/image", params=options) - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = Request( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp @@ -128,14 +124,14 @@ class AsyncTranslate(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def text(self, params: TranslateParams) -> TranslateResponse: @@ -159,6 +155,8 @@ async def image( blob: Union[TranslateImageParams, bytes], options: TranslateImageParams = None, ) -> Union[TranslateImageResponse, bytes]: + path = "/ai/translate/image" + options = options or {} if isinstance(blob, dict): resp = await AsyncRequest( config=self.config, @@ -168,17 +166,12 @@ async def image( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/ai/translate/image", params=options) - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = await AsyncRequest( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp diff --git a/jigsawstack/validate.py b/jigsawstack/validate.py index fc57c3c..d40cf55 100644 --- a/jigsawstack/validate.py +++ b/jigsawstack/validate.py @@ -79,14 +79,14 @@ class Validate(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -99,6 +99,8 @@ def nsfw( blob: Union[NSFWParams, bytes], options: NSFWParams = None, ) -> NSFWResponse: + path = "/validate/nsfw" + options = options or {} if isinstance( blob, dict ): # If params is provided as a dict, we assume it's the first argument @@ -110,17 +112,12 @@ def nsfw( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/validate/nsfw", params=options) - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = Request( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp @@ -168,14 +165,14 @@ class AsyncValidate(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = AsyncRequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -188,28 +185,25 @@ async def nsfw( blob: Union[NSFWParams, bytes], options: NSFWParams = None, ) -> NSFWResponse: + path = "/validate/nsfw" + options = options or {} if isinstance( blob, dict ): # If params is provided as a dict, we assume it's the first argument resp = await AsyncRequest( config=self.config, - path="/validate/nsfw", + path=path, params=cast(Dict[Any, Any], blob), verb="post", ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/validate/nsfw", params=options) - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = await AsyncRequest( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp diff --git a/jigsawstack/version.py b/jigsawstack/version.py index 44573a9..95b9715 100644 --- a/jigsawstack/version.py +++ b/jigsawstack/version.py @@ -1,4 +1,4 @@ -__version__ = "0.3.3" +__version__ = "0.3.4" def get_version() -> str: diff --git a/jigsawstack/vision.py b/jigsawstack/vision.py index 6df4e37..280b71d 100644 --- a/jigsawstack/vision.py +++ b/jigsawstack/vision.py @@ -159,10 +159,10 @@ class OCRResponse(BaseResponse): tags: List[str] has_text: bool sections: List[object] - total_pages: Optional[int] # Only available for PDFs - page_ranges: Optional[ + total_pages: Optional[int] + page_range: Optional[ List[int] - ] # Only available if page_ranges is set in the request parameters. + ] # Only available if page_range is set in the request parameters. class Vision(ClientConfig): @@ -171,14 +171,14 @@ class Vision(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -204,15 +204,12 @@ def vocr( ).perform_with_content() return resp - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = Request( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp @@ -239,16 +236,12 @@ def object_detection( verb="post", ).perform_with_content() return resp - - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = Request( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp @@ -260,14 +253,14 @@ class AsyncVision(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = AsyncRequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) @overload @@ -291,15 +284,12 @@ async def vocr( ).perform_with_content() return resp - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = await AsyncRequest( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp @@ -329,15 +319,12 @@ async def object_detection( ).perform_with_content() return resp - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - + files = {"file": blob} resp = await AsyncRequest( config=self.config, path=path, params=options, - data=blob, - headers=headers, + files=files, verb="post", ).perform_with_content() return resp diff --git a/jigsawstack/web.py b/jigsawstack/web.py index 5d400c3..d432c25 100644 --- a/jigsawstack/web.py +++ b/jigsawstack/web.py @@ -198,14 +198,14 @@ class Web(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = RequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) def ai_scrape(self, params: AIScrapeParams) -> AIScrapeResponse: @@ -248,27 +248,15 @@ def html_to_any( return cast(HTMLToAnyURLResponse, resp) def search(self, params: SearchParams) -> SearchResponse: - s = Search( - self.api_key, - self.api_url, - disable_request_logging=self.config.get("disable_request_logging"), - ) + s = Search(self.api_key, self.base_url, self.headers) return s.search(params) def search_suggestions(self, params: SearchSuggestionsParams) -> SearchSuggestionsResponse: - s = Search( - self.api_key, - self.api_url, - disable_request_logging=self.config.get("disable_request_logging"), - ) + s = Search(self.api_key, self.base_url, self.headers) return s.suggestions(params) def deep_research(self, params: DeepResearchParams) -> DeepResearchResponse: - s = Search( - self.api_key, - self.api_url, - disable_request_logging=self.config.get("disable_request_logging"), - ) + s = Search(self.api_key, self.base_url, self.headers) return s.deep_research(params) @@ -281,14 +269,14 @@ class AsyncWeb(ClientConfig): def __init__( self, api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, + base_url: str, + headers: Union[Dict[str, str], None] = None, ): - super().__init__(api_key, api_url, disable_request_logging) + super().__init__(api_key, base_url, headers) self.config = AsyncRequestConfig( - api_url=api_url, + base_url=base_url, api_key=api_key, - disable_request_logging=disable_request_logging, + headers=headers, ) async def ai_scrape(self, params: AIScrapeParams) -> AIScrapeResponse: @@ -331,27 +319,15 @@ async def html_to_any( return cast(HTMLToAnyURLResponse, resp) async def search(self, params: SearchParams) -> SearchResponse: - s = AsyncSearch( - self.api_key, - self.api_url, - disable_request_logging=self.config.get("disable_request_logging"), - ) + s = AsyncSearch(self.api_key, self.base_url, self.headers) return await s.search(params) async def search_suggestions( self, params: SearchSuggestionsParams ) -> SearchSuggestionsResponse: - s = AsyncSearch( - self.api_key, - self.api_url, - disable_request_logging=self.config.get("disable_request_logging"), - ) + s = AsyncSearch(self.api_key, self.base_url, self.headers) return await s.suggestions(params) async def deep_research(self, params: DeepResearchParams) -> DeepResearchResponse: - s = AsyncSearch( - self.api_key, - self.api_url, - disable_request_logging=self.config.get("disable_request_logging"), - ) + s = AsyncSearch(self.api_key, self.base_url, self.headers) return await s.deep_research(params) diff --git a/setup.py b/setup.py index bfb1aff..1aebb49 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="jigsawstack", - version="0.3.3", + version="0.3.4", description="JigsawStack - The AI SDK for Python", long_description=open("README.md", encoding="utf8").read(), long_description_content_type="text/markdown", @@ -19,7 +19,7 @@ python_requires=">=3.7", keywords=["AI", "AI Tooling"], setup_requires=["pytest-runner"], - tests_require=["pytest"], + tests_require=["pytest", "pytest-asyncio"], test_suite="tests", classifiers=[ "Development Status :: 4 - Beta", diff --git a/tests/test_audio.py b/tests/test_audio.py index 037f285..309b191 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -15,7 +15,6 @@ jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) -# Sample audio URLs for testing AUDIO_URL = AUDIO_URL_LONG = "https://jigsawstack.com/preview/stt-example.wav" @@ -98,7 +97,10 @@ WEBHOOK_TEST_CASES = [ { "name": "with_webhook_url", - "params": {"url": AUDIO_URL, "webhook_url": "https://webhook.site/test-webhook"}, + "params": { + "url": AUDIO_URL, + "webhook_url": "https://webhook.site/test-webhook", + }, "blob": None, "options": None, }, @@ -106,7 +108,10 @@ "name": "with_blob_and_webhook", "params": None, "blob": AUDIO_URL, - "options": {"webhook_url": "https://webhook.site/test-webhook", "language": "en"}, + "options": { + "webhook_url": "https://webhook.site/test-webhook", + "language": "en", + }, }, ] diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 7b6b368..c2bc59d 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -246,7 +246,7 @@ class TestEmbeddingV2Sync: def test_embedding_v2(self, test_case): """Test synchronous embedding v2 with various inputs""" try: - result = jigsaw.embeddingV2(test_case["params"]) + result = jigsaw.embedding_v2(test_case["params"]) assert result["success"] assert "embeddings" in result assert isinstance(result["embeddings"], list) @@ -271,7 +271,7 @@ def test_embedding_v2_blob(self, test_case): try: # Download blob content blob_content = requests.get(test_case["blob_url"]).content - result = jigsaw.embeddingV2(blob_content, test_case["options"]) + result = jigsaw.embedding_v2(blob_content, test_case["options"]) assert result["success"] assert "embeddings" in result assert isinstance(result["embeddings"], list) @@ -291,7 +291,7 @@ class TestEmbeddingV2Async: async def test_embedding_v2_async(self, test_case): """Test asynchronous embedding v2 with various inputs""" try: - result = await async_jigsaw.embeddingV2(test_case["params"]) + result = await async_jigsaw.embedding_v2(test_case["params"]) assert result["success"] assert "embeddings" in result assert isinstance(result["embeddings"], list) @@ -317,7 +317,7 @@ async def test_embedding_v2_blob_async(self, test_case): try: # Download blob content blob_content = requests.get(test_case["blob_url"]).content - result = await async_jigsaw.embeddingV2(blob_content, test_case["options"]) + result = await async_jigsaw.embedding_v2(blob_content, test_case["options"]) assert result["success"] assert "embeddings" in result assert isinstance(result["embeddings"], list) diff --git a/tests/test_vocr.py b/tests/test_vocr.py new file mode 100644 index 0000000..d233484 --- /dev/null +++ b/tests/test_vocr.py @@ -0,0 +1,241 @@ +import logging +import os + +import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +IMAGE_URL = "https://jigsawstack.com/preview/vocr-example.jpg" + +# PDF URL for testing page_range functionality +PDF_URL = "https://arxiv.org/pdf/1706.03762" + +TEST_CASES = [ + { + "name": "with_url_only", + "params": {"url": IMAGE_URL}, + "blob": None, + "options": None, + }, + { + "name": "with_blob_only", + "params": None, + "blob": IMAGE_URL, + "options": None, + }, + { + "name": "with_string_prompt", + "blob": IMAGE_URL, + "options": {"prompt": "Extract all text from the image"}, + }, + { + "name": "with_list_prompt", + "blob": IMAGE_URL, + "options": { + "prompt": [ + "What is the main heading?", + "Extract any dates mentioned", + "What are the key points?", + ] + }, + }, + { + "name": "with_dict_prompt", + "blob": IMAGE_URL, + "options": { + "prompt": { + "title": "Extract the main title", + "content": "What is the main content?", + "metadata": "Extract any metadata or additional information", + } + }, + }, + { + "name": "url_with_string_prompt", + "params": {"url": IMAGE_URL, "prompt": "Summarize the text content"}, + "blob": None, + "options": None, + }, + { + "name": "url_with_list_prompt", + "params": {"url": IMAGE_URL, "prompt": ["Extract headers", "Extract body text"]}, + "blob": None, + "options": None, + }, +] + +# PDF specific test cases +PDF_TEST_CASES = [ + { + "name": "pdf_with_page_range", + "params": {"url": PDF_URL, "page_range": [1, 3], "prompt": "Extract text from these pages"}, + "blob": None, + "options": None, + }, + { + "name": "pdf_single_page", + "params": {"url": PDF_URL, "page_range": [1, 1], "prompt": "What is on the first page?"}, + "blob": None, + "options": None, + }, + { + "name": "pdf_blob_with_page_range", + "blob": PDF_URL, + "options": {"page_range": [1, 3], "prompt": "what is this about?"}, + }, +] + + +class TestVOCRSync: + """Test synchronous VOCR methods""" + + sync_test_cases = TEST_CASES + pdf_test_cases = PDF_TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_vocr(self, test_case): + """Test synchronous VOCR with various inputs""" + try: + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = jigsaw.vision.vocr(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = jigsaw.vision.vocr(test_case["params"]) + + print(f"Test {test_case['name']}: Success={result.get('success')}") + + # Verify response structure + assert result["success"] is True + if "prompt" in (test_case.get("params") or {}): + assert "context" in result + assert "width" in result + assert "height" in result + assert "has_text" in result + assert "tags" in result + assert isinstance(result["tags"], list) + assert "sections" in result + assert isinstance(result["sections"], list) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize("test_case", pdf_test_cases, ids=[tc["name"] for tc in pdf_test_cases]) + def test_vocr_pdf(self, test_case): + """Test synchronous VOCR with PDF inputs""" + try: + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = jigsaw.vision.vocr(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = jigsaw.vision.vocr(test_case["params"]) + + # Verify response structure + assert result["success"] is True + if "prompt" in (test_case.get("params") or {}): + assert "context" in result + assert "total_pages" in result + + if test_case.get("params", {}).get("page_range") or test_case.get("options", {}).get( + "page_range" + ): + assert "page_range" in result + assert isinstance(result["page_range"], list) + + logger.info(f"Test {test_case['name']}: total_pages={result.get('total_pages')}") + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestVOCRAsync: + """Test asynchronous VOCR methods""" + + async_test_cases = TEST_CASES + pdf_test_cases = PDF_TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_vocr_async(self, test_case): + """Test asynchronous VOCR with various inputs""" + try: + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = await async_jigsaw.vision.vocr(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = await async_jigsaw.vision.vocr(test_case["params"]) + + print(f"Test {test_case['name']}: Success={result.get('success')}") + + # Verify response structure + assert result["success"] is True + if "prompt" in (test_case.get("params") or {}): + assert "context" in result + assert "width" in result + assert "height" in result + assert "has_text" in result + assert "tags" in result + assert isinstance(result["tags"], list) + assert "sections" in result + assert isinstance(result["sections"], list) + + # Log some details + logger.info( + f"Test {test_case['name']}: has_text={result['has_text']}, tags={result['tags'][:3] if result['tags'] else []}" + ) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize("test_case", pdf_test_cases, ids=[tc["name"] for tc in pdf_test_cases]) + @pytest.mark.asyncio + async def test_vocr_pdf_async(self, test_case): + """Test asynchronous VOCR with PDF inputs""" + try: + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = await async_jigsaw.vision.vocr(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = await async_jigsaw.vision.vocr(test_case["params"]) + + print(f"Test {test_case['name']}: Success={result.get('success')}") + + # Verify response structure + assert result["success"] is True + if "prompt" in (test_case.get("params") or {}): + assert "context" in result + assert "total_pages" in result # PDF specific + + # Check if page_range is in response when requested + if test_case.get("params", {}).get("page_range") or test_case.get("options", {}).get( + "page_range" + ): + assert "page_range" in result + assert isinstance(result["page_range"], list) + + logger.info(f"Test {test_case['name']}: total_pages={result.get('total_pages')}") + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}")