diff --git a/.github/ruff.toml b/.github/ruff.toml new file mode 100644 index 0000000..3922be1 --- /dev/null +++ b/.github/ruff.toml @@ -0,0 +1,16 @@ +# Ruff configuration for CI/CD +line-length = 100 +target-version = "py37" + +[lint] +select = [ + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "B008", # do not perform function calls in argument defaults +] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b1f5b26 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,93 @@ +name: CI + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + ruff-format-check: + name: Ruff Format Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install ruff + run: pip install ruff + + - name: Check all files with ruff + run: | + ruff check jigsawstack/ --config .github/ruff.toml + ruff format --check jigsawstack/ --config .github/ruff.toml + + test: + name: Test - ${{ matrix.test-file }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + test-file: + - test_audio.py + - test_classification.py + - test_embedding.py + - test_file_store.py + - test_image_generation.py + - test_object_detection.py + - test_prediction.py + - test_sentiment.py + - test_sql.py + - test_summary.py + - test_translate.py + - test_validate.py + - test_web.py + - test_deep_research.py + - test_ai_scrape.py + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest pytest-asyncio pytest-cov python-dotenv + pip install -e . 
+ + - name: Run test ${{ matrix.test-file }} + env: + JIGSAWSTACK_API_KEY: ${{ secrets.JIGSAWSTACK_API_KEY }} + run: | + pytest tests/${{ matrix.test-file }} -v + + all-checks-passed: + name: All Checks Passed + needs: [ruff-format-check, test] + runs-on: ubuntu-latest + if: always() + steps: + - name: Verify all checks passed + run: | + echo "Ruff Format Check: ${{ needs.ruff-format-check.result }}" + echo "Tests: ${{ needs.test.result }}" + + if [[ "${{ needs.ruff-format-check.result }}" != "success" ]]; then + echo "❌ Ruff format check failed" + exit 1 + fi + + if [[ "${{ needs.test.result }}" != "success" ]]; then + echo "❌ Tests failed" + exit 1 + fi + + echo "✅ All checks passed successfully!" \ No newline at end of file diff --git a/README.md b/README.md index edf4020..e13c6bf 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ VOCR: ```py params = { - "url": "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg?t=2024-03-22T09%3A22%3A48.442Z" + "url": "https://images.unsplash.com/photo-1542931287-023b922fa89b?q=80&w=2574&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D?t=2024-03-22T09%3A22%3A48.442Z" } result = jigsaw.vision.vocr(params) ``` diff --git a/biome.json b/biome.json deleted file mode 100644 index 5ad6df5..0000000 --- a/biome.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "$schema": "https://biomejs.dev/schemas/1.9.4/schema.json", - "files": { - "ignoreUnknown": false, - "ignore": [] - }, - "formatter": { - "enabled": true, - "useEditorconfig": true, - "formatWithErrors": false, - "indentStyle": "space", - "indentWidth": 2, - "lineEnding": "lf", - "lineWidth": 150, - "attributePosition": "auto", - "bracketSpacing": true - }, - "organizeImports": { - "enabled": true - }, - "linter": { - "enabled": false - }, - "javascript": { - "formatter": { - "jsxQuoteStyle": "double", - "quoteProperties": "asNeeded", - "trailingCommas": "es5", - "semicolons": "always", - 
"arrowParentheses": "always", - "bracketSameLine": false, - "quoteStyle": "double", - "attributePosition": "auto", - "bracketSpacing": true - } - } -} diff --git a/jigsawstack/__init__.py b/jigsawstack/__init__.py index e860936..091f775 100644 --- a/jigsawstack/__init__.py +++ b/jigsawstack/__init__.py @@ -1,22 +1,23 @@ -from typing import Union, Dict import os -from .audio import Audio, AsyncAudio -from .vision import Vision, AsyncVision +from typing import Dict, Union + +from .audio import AsyncAudio, Audio +from .classification import AsyncClassification, Classification +from .embedding import AsyncEmbedding, Embedding +from .embedding_v2 import AsyncEmbeddingV2, EmbeddingV2 +from .exceptions import JigsawStackError +from .image_generation import AsyncImageGeneration, ImageGeneration +from .prediction import AsyncPrediction, Prediction +from .prompt_engine import AsyncPromptEngine, PromptEngine from .search import Search -from .prediction import Prediction, AsyncPrediction +from .sentiment import AsyncSentiment, Sentiment from .sql import SQL, AsyncSQL -from .store import Store, AsyncStore -from .translate import Translate, AsyncTranslate -from .web import Web, AsyncWeb -from .sentiment import Sentiment, AsyncSentiment -from .validate import Validate, AsyncValidate -from .summary import Summary, AsyncSummary -from .embedding import Embedding, AsyncEmbedding -from .exceptions import JigsawStackError -from .image_generation import ImageGeneration, AsyncImageGeneration -from .classification import Classification, AsyncClassification -from .prompt_engine import PromptEngine, AsyncPromptEngine -from .embeddingV2 import EmbeddingV2, AsyncEmbeddingV2 +from .store import AsyncStore, Store +from .summary import AsyncSummary, Summary +from .translate import AsyncTranslate, Translate +from .validate import AsyncValidate, Validate +from .vision import AsyncVision, Vision +from .web import AsyncWeb, Web class JigsawStack: @@ -51,7 +52,7 @@ def __init__( if api_url is None: 
api_url = os.environ.get("JIGSAWSTACK_API_URL") if api_url is None: - api_url = f"https://api.jigsawstack.com/" + api_url = "https://api.jigsawstack.com/" self.api_key = api_key self.api_url = api_url @@ -171,7 +172,7 @@ def __init__( if api_url is None: api_url = os.environ.get("JIGSAWSTACK_API_URL") if api_url is None: - api_url = f"https://api.jigsawstack.com/" + api_url = "https://api.jigsawstack.com/" self.api_key = api_key self.api_url = api_url diff --git a/jigsawstack/async_request.py b/jigsawstack/async_request.py index 033b39b..26a7e53 100644 --- a/jigsawstack/async_request.py +++ b/jigsawstack/async_request.py @@ -1,8 +1,11 @@ -from typing import Any, Dict, Generic, List, Union, cast, TypedDict, AsyncGenerator +import json +from io import BytesIO +from typing import Any, AsyncGenerator, Dict, Generic, List, TypedDict, Union, cast + import aiohttp from typing_extensions import Literal, TypeVar + from .exceptions import NoContentError, raise_for_code_and_type -import json RequestVerb = Literal["get", "post", "put", "patch", "delete"] @@ -22,7 +25,7 @@ def __init__( path: str, params: Union[Dict[Any, Any], List[Dict[Any, Any]]], verb: RequestVerb, - headers: Dict[str, str] = {"Content-Type": "application/json"}, + headers: Dict[str, str] = None, data: Union[bytes, None] = None, stream: Union[bool, None] = False, ): @@ -32,7 +35,7 @@ def __init__( self.api_url = config.get("api_url") self.api_key = config.get("api_key") self.data = data - self.headers = headers + self.headers = headers or {"Content-Type": "application/json"} self.disable_request_logging = config.get("disable_request_logging") self.stream = stream @@ -243,12 +246,27 @@ async def make_request( ) else: if data is not None: + form_data = aiohttp.FormData() + form_data.add_field( + "file", + BytesIO(data), + content_type=headers.get("Content-Type", "application/octet-stream"), + filename="file", + ) + + if self.params and isinstance(self.params, dict): + form_data.add_field( + "body", 
json.dumps(self.params), content_type="application/json" + ) + + multipart_headers = headers.copy() + multipart_headers.pop("Content-Type", None) + return await session.request( verb, url, - data=data, - params=converted_params, # Use converted params - headers=headers, + data=form_data, + headers=multipart_headers, ) else: return await session.request( diff --git a/jigsawstack/audio.py b/jigsawstack/audio.py index 2046c58..cadfd25 100644 --- a/jigsawstack/audio.py +++ b/jigsawstack/audio.py @@ -1,13 +1,11 @@ -from typing import Any, Dict, List, cast, Union, Optional, overload -from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest, AsyncRequestConfig +from typing import Any, Dict, List, Optional, Union, cast, overload + +from typing_extensions import Literal, NotRequired, TypedDict + from ._config import ClientConfig -from typing import Any, Dict, List, cast -from typing_extensions import NotRequired, TypedDict, Literal -from .custom_typing import SupportedAccents -from .helpers import build_path from ._types import BaseResponse +from .async_request import AsyncRequest, AsyncRequestConfig +from .request import Request, RequestConfig class SpeechToTextParams(TypedDict): @@ -80,22 +78,21 @@ def speech_to_text( blob: Union[SpeechToTextParams, bytes], options: Optional[SpeechToTextParams] = None, ) -> Union[SpeechToTextResponse, SpeechToTextWebhookResponse]: + options = options or {} + path = "/ai/transcribe" + content_type = options.get("content_type", "application/octet-stream") + headers = {"Content-Type": content_type} if isinstance( blob, dict ): # If params is provided as a dict, we assume it's the first argument resp = Request( config=self.config, - path="/ai/transcribe", + path=path, params=cast(Dict[Any, Any], blob), verb="post", ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/ai/transcribe", params=options) - content_type = 
options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - resp = Request( config=self.config, path=path, @@ -137,20 +134,19 @@ async def speech_to_text( blob: Union[SpeechToTextParams, bytes], options: Optional[SpeechToTextParams] = None, ) -> Union[SpeechToTextResponse, SpeechToTextWebhookResponse]: + options = options or {} + path = "/ai/transcribe" + content_type = options.get("content_type", "application/octet-stream") + headers = {"Content-Type": content_type} if isinstance(blob, dict): resp = await AsyncRequest( config=self.config, - path="/ai/transcribe", + path=path, params=cast(Dict[Any, Any], blob), verb="post", ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/ai/transcribe", params=options) - content_type = options.get("content_type", "application/octet-stream") - headers = {"Content-Type": content_type} - resp = await AsyncRequest( config=self.config, path=path, diff --git a/jigsawstack/classification.py b/jigsawstack/classification.py index a53ed87..45407e9 100644 --- a/jigsawstack/classification.py +++ b/jigsawstack/classification.py @@ -1,9 +1,11 @@ from typing import Any, Dict, List, Union, cast -from typing_extensions import NotRequired, TypedDict, Literal -from .request import Request, RequestConfig -from .async_request import AsyncRequest, AsyncRequestConfig + +from typing_extensions import Literal, NotRequired, TypedDict + from ._config import ClientConfig from ._types import BaseResponse +from .async_request import AsyncRequest, AsyncRequestConfig +from .request import Request, RequestConfig class DatasetItem(TypedDict): diff --git a/jigsawstack/custom_typing.py b/jigsawstack/custom_typing.py deleted file mode 100644 index e77adde..0000000 --- a/jigsawstack/custom_typing.py +++ /dev/null @@ -1,574 +0,0 @@ -from typing import Literal - -SupportedAccents = Literal[ - "af-ZA-female-1", - "af-ZA-male-1", - "am-ET-female-1", - "am-ET-male-1", - "ar-AE-female-1", 
- "ar-AE-male-1", - "ar-BH-female-1", - "ar-BH-male-1", - "ar-DZ-female-1", - "ar-DZ-male-1", - "ar-EG-female-1", - "ar-EG-male-1", - "ar-IQ-female-1", - "ar-IQ-male-1", - "ar-JO-female-1", - "ar-JO-male-1", - "ar-KW-female-1", - "ar-KW-male-1", - "ar-LB-female-1", - "ar-LB-male-1", - "ar-LY-female-1", - "ar-LY-male-1", - "ar-MA-female-1", - "ar-MA-male-1", - "ar-OM-female-1", - "ar-OM-male-1", - "ar-QA-female-1", - "ar-QA-male-1", - "ar-SA-female-1", - "ar-SA-male-1", - "ar-SY-female-1", - "ar-SY-male-1", - "ar-TN-female-1", - "ar-TN-male-1", - "ar-YE-female-1", - "ar-YE-male-1", - "as-IN-male-1", - "as-IN-female-1", - "az-AZ-female-1", - "az-AZ-male-1", - "bg-BG-female-1", - "bg-BG-male-1", - "bn-BD-female-1", - "bn-BD-male-1", - "bn-IN-female-1", - "bn-IN-male-1", - "bs-BA-female-1", - "bs-BA-male-1", - "ca-ES-female-1", - "ca-ES-male-1", - "ca-ES-female-2", - "cs-CZ-female-1", - "cs-CZ-male-1", - "cy-GB-female-1", - "cy-GB-male-1", - "da-DK-female-1", - "da-DK-male-1", - "de-AT-female-1", - "de-AT-male-1", - "de-CH-female-1", - "de-CH-male-1", - "de-DE-female-1", - "de-DE-male-1", - "de-DE-female-2", - "de-DE-male-2", - "de-DE-male-3", - "de-DE-female-3", - "de-DE-male-4", - "de-DE-male-5", - "de-DE-female-4", - "de-DE-male-6", - "de-DE-male-7", - "de-DE-female-5", - "de-DE-male-8", - "de-DE-female-6", - "de-DE-female-7", - "de-DE-male-9", - "de-DE-female-8", - "de-DE-female-9", - "de-DE-female-10", - "el-GR-female-2", - "el-GR-male-2", - "en-AU-female-2", - "en-AU-male-2", - "en-AU-female-3", - "en-AU-female-4", - "en-AU-male-3", - "en-AU-male-4", - "en-AU-female-5", - "en-AU-female-6", - "en-AU-female-7", - "en-AU-male-5", - "en-AU-female-8", - "en-AU-male-6", - "en-AU-male-7", - "en-AU-female-9", - "en-CA-female-2", - "en-CA-male-2", - "en-GB-female-2", - "en-GB-male-2", - "en-GB-female-3", - "en-GB-female-4", - "en-GB-male-3", - "en-GB-female-5", - "en-GB-male-4", - "en-GB-male-5", - "en-GB-female-6", - "en-GB-female-7", - "en-GB-male-6", - "en-GB-male-7", 
- "en-GB-female-8", - "en-GB-male-8", - "en-GB-female-9", - "en-GB-female-10", - "en-GB-male-9", - "en-GB-male-10", - "en-GB-female-11", - "en-HK-female-1", - "en-HK-male-1", - "en-IE-female-3", - "en-IE-male-3", - "en-IN-female-3", - "en-IN-male-3", - "en-IN-male-4", - "en-IN-female-4", - "en-IN-female-5", - "en-IN-female-6", - "en-IN-male-5", - "en-IN-male-6", - "en-KE-female-1", - "en-KE-male-1", - "en-NG-female-1", - "en-NG-male-1", - "en-NZ-female-1", - "en-NZ-male-1", - "en-PH-female-1", - "en-PH-male-1", - "en-SG-female-1", - "en-SG-male-1", - "en-TZ-female-1", - "en-TZ-male-1", - "en-US-female-3", - "en-US-female-4", - "en-US-male-3", - "en-US-male-4", - "en-US-female-5", - "en-US-female-6", - "en-US-male-5", - "en-US-male-6", - "en-US-female-7", - "en-US-male-7", - "en-US-female-8", - "en-US-male-8", - "en-US-female-9", - "en-US-male-9", - "en-US-female-10", - "en-US-male-10", - "en-US-female-11", - "en-US-male-11", - "en-US-female-12", - "en-US-male-12", - "en-US-female-13", - "en-US-female-14", - "en-US-female-15", - "en-US-female-16", - "en-US-male-13", - "en-US-male-14", - "en-US-female-17", - "en-US-female-18", - "en-US-male-15", - "en-US-male-16", - "en-US-female-19", - "en-US-female-20", - "en-US-female-21", - "en-US-female-22", - "en-US-male-17", - "en-US-male-18", - "en-US-male-19", - "en-US-male-20", - "en-US-male-21", - "en-US-female-23", - "en-US-male-22", - "en-US-male-23", - "en-US-neutral-1", - "en-US-male-24", - "en-US-male-25", - "en-US-male-26", - "en-US-male-27", - "en-US-female-24", - "en-US-female-25", - "en-US-female-26", - "en-US-female-27", - "en-US-male-28", - "en-US-female-28", - "en-US-female-29", - "en-US-female-30", - "en-US-male-29", - "en-US-male-30", - "en-ZA-female-1", - "en-ZA-male-1", - "es-AR-female-1", - "es-AR-male-1", - "es-BO-female-1", - "es-BO-male-1", - "es-CL-female-1", - "es-CL-male-1", - "es-CO-female-1", - "es-CO-male-1", - "es-CR-female-1", - "es-CR-male-1", - "es-CU-female-1", - "es-CU-male-1", - 
"es-DO-female-1", - "es-DO-male-1", - "es-EC-female-1", - "es-EC-male-1", - "es-ES-female-9", - "es-ES-male-10", - "es-ES-female-10", - "es-ES-male-11", - "es-ES-male-12", - "es-ES-male-13", - "es-ES-female-11", - "es-ES-female-12", - "es-ES-female-13", - "es-ES-female-14", - "es-ES-male-14", - "es-ES-male-15", - "es-ES-male-16", - "es-ES-female-15", - "es-ES-female-16", - "es-ES-female-17", - "es-ES-female-18", - "es-ES-female-19", - "es-ES-female-20", - "es-ES-female-21", - "es-ES-male-17", - "es-ES-male-18", - "es-ES-female-22", - "es-ES-female-23", - "es-GQ-female-1", - "es-GQ-male-1", - "es-GT-female-1", - "es-GT-male-1", - "es-HN-female-1", - "es-HN-male-1", - "es-MX-female-12", - "es-MX-male-11", - "es-MX-female-13", - "es-MX-female-14", - "es-MX-female-15", - "es-MX-male-12", - "es-MX-male-13", - "es-MX-female-16", - "es-MX-male-14", - "es-MX-male-15", - "es-MX-female-17", - "es-MX-female-18", - "es-MX-male-16", - "es-MX-female-19", - "es-MX-male-17", - "es-NI-female-1", - "es-NI-male-1", - "es-PA-female-1", - "es-PA-male-1", - "es-PE-female-1", - "es-PE-male-1", - "es-PR-female-1", - "es-PR-male-1", - "es-PY-female-1", - "es-PY-male-1", - "es-SV-female-1", - "es-SV-male-1", - "es-US-female-1", - "es-US-male-1", - "es-UY-female-1", - "es-UY-male-1", - "es-VE-female-1", - "es-VE-male-1", - "et-EE-female-11", - "et-EE-male-10", - "eu-ES-female-11", - "eu-ES-male-10", - "fa-IR-female-11", - "fa-IR-male-10", - "fi-FI-female-12", - "fi-FI-male-11", - "fi-FI-female-13", - "fil-PH-female-11", - "fil-PH-male-10", - "fr-BE-female-12", - "fr-BE-male-11", - "fr-CA-female-12", - "fr-CA-male-11", - "fr-CA-male-12", - "fr-CA-male-13", - "fr-CH-female-12", - "fr-CH-male-11", - "fr-FR-female-12", - "fr-FR-male-11", - "fr-FR-male-12", - "fr-FR-female-13", - "fr-FR-female-14", - "fr-FR-male-13", - "fr-FR-female-15", - "fr-FR-female-16", - "fr-FR-female-17", - "fr-FR-male-14", - "fr-FR-female-18", - "fr-FR-male-15", - "fr-FR-male-16", - "fr-FR-male-17", - "fr-FR-female-19", - 
"fr-FR-female-20", - "fr-FR-male-18", - "fr-FR-female-21", - "fr-FR-male-19", - "fr-FR-male-20", - "ga-IE-female-12", - "ga-IE-male-12", - "gl-ES-female-12", - "gl-ES-male-12", - "gu-IN-female-1", - "gu-IN-male-1", - "he-IL-female-12", - "he-IL-male-12", - "hi-IN-female-13", - "hi-IN-male-13", - "hi-IN-male-14", - "hi-IN-female-14", - "hi-IN-female-15", - "hi-IN-male-15", - "hi-IN-male-16", - "hr-HR-female-12", - "hr-HR-male-12", - "hu-HU-female-13", - "hu-HU-male-13", - "hy-AM-female-12", - "hy-AM-male-12", - "id-ID-female-13", - "id-ID-male-13", - "is-IS-female-12", - "is-IS-male-12", - "it-IT-female-13", - "it-IT-female-14", - "it-IT-male-13", - "it-IT-male-14", - "it-IT-male-15", - "it-IT-male-16", - "it-IT-female-15", - "it-IT-female-16", - "it-IT-male-17", - "it-IT-male-18", - "it-IT-female-17", - "it-IT-female-18", - "it-IT-male-19", - "it-IT-female-19", - "it-IT-female-20", - "it-IT-male-20", - "it-IT-male-21", - "it-IT-male-22", - "it-IT-male-23", - "it-IT-male-24", - "it-IT-female-21", - "it-IT-female-22", - "it-IT-male-25", - "it-IT-male-26", - "iu-Cans-CA-female-1", - "iu-Cans-CA-male-1", - "iu-Latn-CA-female-1", - "iu-Latn-CA-male-1", - "ja-JP-female-14", - "ja-JP-male-16", - "ja-JP-female-15", - "ja-JP-male-17", - "ja-JP-female-16", - "ja-JP-male-18", - "ja-JP-female-17", - "ja-JP-male-19", - "ja-JP-male-20", - "jv-ID-female-13", - "jv-ID-male-16", - "ka-GE-female-13", - "ka-GE-male-16", - "kk-KZ-female-13", - "kk-KZ-male-16", - "km-KH-female-13", - "km-KH-male-16", - "kn-IN-female-13", - "kn-IN-male-16", - "ko-KR-female-14", - "ko-KR-male-17", - "ko-KR-male-18", - "ko-KR-male-19", - "ko-KR-male-20", - "ko-KR-female-15", - "ko-KR-female-16", - "ko-KR-female-17", - "ko-KR-female-18", - "ko-KR-male-21", - "ko-KR-male-22", - "lo-LA-female-13", - "lo-LA-male-17", - "lt-LT-female-13", - "lt-LT-male-17", - "lv-LV-female-13", - "lv-LV-male-17", - "mk-MK-female-13", - "mk-MK-male-17", - "ml-IN-female-13", - "ml-IN-male-17", - "mn-MN-female-13", - 
"mn-MN-male-17", - "mr-IN-female-1", - "mr-IN-male-1", - "ms-MY-female-13", - "ms-MY-male-17", - "mt-MT-female-13", - "mt-MT-male-17", - "my-MM-female-13", - "my-MM-male-17", - "nb-NO-female-14", - "nb-NO-male-18", - "nb-NO-female-15", - "ne-NP-female-13", - "ne-NP-male-17", - "nl-BE-female-14", - "nl-BE-male-18", - "nl-NL-female-14", - "nl-NL-male-18", - "nl-NL-female-15", - "or-IN-female-1", - "or-IN-male-1", - "pa-IN-male-1", - "pa-IN-female-1", - "pl-PL-female-14", - "pl-PL-male-18", - "pl-PL-female-15", - "ps-AF-female-13", - "ps-AF-male-17", - "pt-BR-female-14", - "pt-BR-male-18", - "pt-BR-female-15", - "pt-BR-male-19", - "pt-BR-female-16", - "pt-BR-male-20", - "pt-BR-female-17", - "pt-BR-male-21", - "pt-BR-male-22", - "pt-BR-female-18", - "pt-BR-female-19", - "pt-BR-female-20", - "pt-BR-male-23", - "pt-BR-female-21", - "pt-BR-male-24", - "pt-BR-female-22", - "pt-BR-male-25", - "pt-BR-male-26", - "pt-BR-female-23", - "pt-BR-female-24", - "pt-PT-female-15", - "pt-PT-male-19", - "pt-PT-female-16", - "ro-RO-female-14", - "ro-RO-male-18", - "ru-RU-female-15", - "ru-RU-male-19", - "ru-RU-female-16", - "si-LK-female-14", - "si-LK-male-18", - "sk-SK-female-14", - "sk-SK-male-18", - "sl-SI-female-14", - "sl-SI-male-18", - "so-SO-female-14", - "so-SO-male-18", - "sq-AL-female-14", - "sq-AL-male-18", - "sr-Latn-RS-male-1", - "sr-Latn-RS-female-1", - "sr-RS-female-14", - "sr-RS-male-18", - "su-ID-female-14", - "su-ID-male-18", - "sv-SE-female-15", - "sv-SE-male-19", - "sv-SE-female-16", - "sw-KE-female-14", - "sw-KE-male-18", - "sw-TZ-female-1", - "sw-TZ-male-1", - "ta-IN-female-14", - "ta-IN-male-18", - "ta-LK-female-1", - "ta-LK-male-1", - "ta-MY-female-1", - "ta-MY-male-1", - "ta-SG-female-1", - "ta-SG-male-1", - "te-IN-female-14", - "te-IN-male-18", - "th-TH-female-15", - "th-TH-male-19", - "th-TH-female-16", - "tr-TR-female-15", - "tr-TR-male-19", - "uk-UA-female-14", - "uk-UA-male-18", - "ur-IN-female-1", - "ur-IN-male-1", - "ur-PK-female-14", - "ur-PK-male-18", - 
"uz-UZ-female-14", - "uz-UZ-male-18", - "vi-VN-female-14", - "vi-VN-male-18", - "wuu-CN-female-1", - "wuu-CN-male-1", - "yue-CN-female-1", - "yue-CN-male-1", - "zh-CN-female-15", - "zh-CN-male-19", - "zh-CN-male-20", - "zh-CN-female-16", - "zh-CN-male-21", - "zh-CN-female-17", - "zh-CN-female-18", - "zh-CN-female-19", - "zh-CN-female-20", - "zh-CN-female-21", - "zh-CN-female-22", - "zh-CN-female-23", - "zh-CN-female-24", - "zh-CN-female-25", - "zh-CN-female-26", - "zh-CN-female-27", - "zh-CN-female-28", - "zh-CN-female-29", - "zh-CN-female-30", - "zh-CN-female-31", - "zh-CN-female-32", - "zh-CN-female-33", - "zh-CN-female-34", - "zh-CN-male-22", - "zh-CN-male-23", - "zh-CN-male-24", - "zh-CN-male-25", - "zh-CN-male-26", - "zh-CN-male-27", - "zh-CN-male-28", - "zh-CN-male-29", - "zh-CN-male-30", - "zh-CN-male-31", - "zh-CN-male-32", - "zh-CN-male-33", - "zh-CN-guangxi-male-1", - "zh-CN-henan-male-1", - "zh-CN-liaoning-female-2", - "zh-CN-liaoning-male-1", - "zh-CN-shaanxi-female-2", - "zh-CN-shandong-male-1", - "zh-CN-sichuan-male-1", - "zh-HK-female-18", - "zh-HK-male-22", - "zh-HK-female-19", - "zh-TW-female-19", - "zh-TW-male-22", - "zh-TW-female-20", - "zu-ZA-female-17", - "zu-ZA-male-21", -] diff --git a/jigsawstack/embedding.py b/jigsawstack/embedding.py index 70a8359..cd755f0 100644 --- a/jigsawstack/embedding.py +++ b/jigsawstack/embedding.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, List, Union, cast, Literal, overload +from typing import Any, Dict, List, Literal, Union, cast, overload + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union + from ._config import ClientConfig -from .helpers import build_path from ._types import BaseResponse +from .async_request import AsyncRequest +from .helpers import build_path +from .request import Request, RequestConfig class EmbeddingParams(TypedDict): @@ -46,9 +47,7 @@ def __init__( @overload def 
execute(self, params: EmbeddingParams) -> EmbeddingResponse: ... @overload - def execute( - self, blob: bytes, options: EmbeddingParams = None - ) -> EmbeddingResponse: ... + def execute(self, blob: bytes, options: EmbeddingParams = None) -> EmbeddingResponse: ... def execute( self, @@ -100,9 +99,7 @@ def __init__( @overload async def execute(self, params: EmbeddingParams) -> EmbeddingResponse: ... @overload - async def execute( - self, blob: bytes, options: EmbeddingParams = None - ) -> EmbeddingResponse: ... + async def execute(self, blob: bytes, options: EmbeddingParams = None) -> EmbeddingResponse: ... async def execute( self, diff --git a/jigsawstack/embeddingV2.py b/jigsawstack/embedding_v2.py similarity index 92% rename from jigsawstack/embeddingV2.py rename to jigsawstack/embedding_v2.py index d7559bb..fe62f69 100644 --- a/jigsawstack/embeddingV2.py +++ b/jigsawstack/embedding_v2.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, List, Union, cast, Literal, overload +from typing import Any, Dict, List, Literal, Union, cast, overload + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union + from ._config import ClientConfig -from .helpers import build_path +from .async_request import AsyncRequest from .embedding import Chunk +from .helpers import build_path +from .request import Request, RequestConfig class EmbeddingV2Params(TypedDict): @@ -14,7 +15,7 @@ class EmbeddingV2Params(TypedDict): type: Literal["text", "text-other", "image", "audio", "pdf"] url: NotRequired[str] file_store_key: NotRequired[str] - token_overflow_mode: NotRequired[Literal["truncate", "chunk", "error"]] = "chunk" + token_overflow_mode: NotRequired[Literal["truncate", "error"]] speaker_fingerprint: NotRequired[bool] @@ -44,9 +45,7 @@ def __init__( @overload def execute(self, params: EmbeddingV2Params) -> EmbeddingV2Response: ... 
@overload - def execute( - self, blob: bytes, options: EmbeddingV2Params = None - ) -> EmbeddingV2Response: ... + def execute(self, blob: bytes, options: EmbeddingV2Params = None) -> EmbeddingV2Response: ... def execute( self, diff --git a/jigsawstack/geo.py b/jigsawstack/geo.py deleted file mode 100644 index cd182ba..0000000 --- a/jigsawstack/geo.py +++ /dev/null @@ -1,406 +0,0 @@ -from typing import Any, Dict, List, Union, cast -from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequestConfig, AsyncRequest -from typing import List, Union -from ._config import ClientConfig - - -class BaseResponse: - success: bool - - -class GeoParams(TypedDict): - search_value: str - lat: str - lng: str - country_code: str - proximity_lat: str - proximity_lng: str - types: str - city_code: str - state_code: str - limit: int - - -class GeoSearchParams(TypedDict): - search_value: str - country_code: NotRequired[str] = None - proximity_lat: NotRequired[str] = None - proximity_lng: NotRequired[str] = None - types: NotRequired[str] = None - - -class Geoloc(TypedDict): - type: str - coordinates: List[float] - - -class Region(TypedDict): - name: str - region_code: str - region_code_full: str - - -class Country(TypedDict): - name: str - country_code: str - country_code_alpha_3: str - - -class GeoSearchResult(TypedDict): - type: str - full_address: str - name: str - place_formatted: str - postcode: str - place: str - region: Region - country: Country - language: str - geoloc: Geoloc - poi_category: List[str] - additional_properties: Dict[str, any] - - -class CityResult(TypedDict): - state_code: str - name: str - city_code: str - state: "StateResult" - - -class CountryResult(TypedDict): - country_code: str - name: str - iso2: str - iso3: str - capital: str - phone_code: str - region: str - subregion: str - currency_code: str - geoloc: Geoloc - currency_name: str - currency_symbol: str - tld: str - native: str - 
emoji: str - emojiU: str - latitude: float - longitude: float - - -class StateResult(TypedDict): - state_code: str - name: str - country_code: str - country: CountryResult - - -class GeoSearchResponse(BaseResponse): - data: List[GeoSearchResult] - - -class GeocodeParams(TypedDict): - search_value: str - lat: str - lng: str - country_code: str - proximity_lat: str - proximity_lng: str - types: str - limit: int - - -class GeoCityParams(TypedDict): - country_code: str - city_code: str - state_code: str - search_value: str - lat: str - lng: str - limit: int - - -class GeoCityResponse(BaseResponse): - city: List[CityResult] - - -class GeoCountryParams(TypedDict): - country_code: str - city_code: str - search_value: str - lat: str - lng: str - limit: int - currency_code: str - - -class GeoCountryResponse(BaseResponse): - country: List[CountryResult] - - -class GeoStateParams(TypedDict): - country_code: str - state_code: str - search_value: str - lat: str - lng: str - limit: int - - -class GeoStateResponse(BaseResponse): - state: List[StateResult] - - -class GeoDistanceParams(TypedDict): - unit: NotRequired[str] = None # "K" or "N" - lat1: str - lng1: str - lat2: str - lng2: str - - -class GeoDistanceResponse(BaseResponse): - distance: float - - -class GeoTimezoneParams(TypedDict): - lat: str - lng: str - city_code: NotRequired[str] = None - country_code: NotRequired[str] = None - - -class GeoTimezoneResponse(BaseResponse): - timezone: Dict[str, any] - - -class GeohashParams(TypedDict): - lat: str - lng: str - precision: int - - -class GeohashResponse(BaseResponse): - geohash: str - - -class GeohashDecodeResponse(BaseResponse): - latitude: float - longitude: float - - -class Geo(ClientConfig): - config: RequestConfig - - def __init__( - self, - api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, - ): - super().__init__(api_key, api_url, disable_request_logging) - self.config = RequestConfig( - api_url=api_url, - api_key=api_key, - 
disable_request_logging=disable_request_logging, - ) - - def search(self, params: GeoSearchParams) -> GeoSearchResponse: - path = "/geo/search" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def geocode(self, params: GeocodeParams) -> GeohashDecodeResponse: - path = "/geo/geocode" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def city(self, params: GeoCityParams) -> GeoCityResponse: - path = "/geo/city" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def country(self, params: GeoCountryParams) -> GeoCountryResponse: - path = "/geo/country" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def state(self, params: GeoStateParams) -> GeoStateResponse: - path = "/geo/state" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def distance(self, params: GeoDistanceParams) -> GeoDistanceResponse: - path = "/geo/distance" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def timezone(self, params: GeoTimezoneParams) -> GeoTimezoneResponse: - path = "/geo/timezone" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def geohash(self, params: GeohashParams) -> GeohashResponse: - path = "/geo/geohash" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - def geohash(self, key: str) 
-> GeohashDecodeResponse: - path = f"/geo/geohash/decode/{key}" - resp = Request( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params={}), - verb="get", - ).perform_with_content() - return resp - - -class AsyncGeo(ClientConfig): - config: AsyncRequestConfig - - def __init__( - self, - api_key: str, - api_url: str, - disable_request_logging: Union[bool, None] = False, - ): - super().__init__(api_key, api_url, disable_request_logging) - self.config = AsyncRequestConfig( - api_url=api_url, - api_key=api_key, - disable_request_logging=disable_request_logging, - ) - - async def search(self, params: GeoSearchParams) -> GeoSearchResponse: - path = "/geo/search" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def geocode(self, params: GeocodeParams) -> GeohashDecodeResponse: - path = "/geo/geocode" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def city(self, params: GeoCityParams) -> GeoCityResponse: - path = "/geo/city" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def country(self, params: GeoCountryParams) -> GeoCountryResponse: - path = "/geo/country" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def state(self, params: GeoStateParams) -> GeoStateResponse: - path = "/geo/state" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def distance(self, params: GeoDistanceParams) -> GeoDistanceResponse: - path = "/geo/distance" - resp = await AsyncRequest( - 
config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def timezone(self, params: GeoTimezoneParams) -> GeoTimezoneResponse: - path = "/geo/timezone" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def geohash(self, params: GeohashParams) -> GeohashResponse: - path = "/geo/geohash" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params), - verb="get", - ).perform_with_content() - return resp - - async def geohash(self, key: str) -> GeohashDecodeResponse: - path = f"/geo/geohash/decode/{key}" - resp = await AsyncRequest( - config=self.config, - path=path, - params=cast(Dict[Any, Any], params={}), - verb="get", - ).perform_with_content() - return resp diff --git a/jigsawstack/helpers.py b/jigsawstack/helpers.py index 1854410..5c1ad6a 100644 --- a/jigsawstack/helpers.py +++ b/jigsawstack/helpers.py @@ -2,9 +2,7 @@ from urllib.parse import urlencode -def build_path( - base_path: str, params: Optional[Dict[str, Union[str, int, bool]]] = None -) -> str: +def build_path(base_path: str, params: Optional[Dict[str, Union[str, int, bool]]] = None) -> str: """ Build an API endpoint path with query parameters. 
@@ -20,9 +18,7 @@ def build_path( # remove None values from the parameters filtered_params = { - k: str(v).lower() if isinstance(v, bool) else v - for k, v in params.items() - if v is not None + k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items() if v is not None } # encode the parameters diff --git a/jigsawstack/image_generation.py b/jigsawstack/image_generation.py index b868ada..9584cf3 100644 --- a/jigsawstack/image_generation.py +++ b/jigsawstack/image_generation.py @@ -1,10 +1,10 @@ -from typing import Any, Dict, List, Union, cast -from typing_extensions import NotRequired, TypedDict, Literal, Required -from .request import Request, RequestConfig -from .async_request import AsyncRequest +from typing import Any, Dict, Union, cast + +from typing_extensions import Literal, NotRequired, Required, TypedDict -from typing import List, Union from ._config import ClientConfig +from .async_request import AsyncRequest +from .request import Request, RequestConfig class AdvanceConfig(TypedDict): @@ -77,9 +77,9 @@ class ImageGenerationResponse(TypedDict): """ Indicates whether the image generation was successful. """ - image: bytes + url: NotRequired[str] """ - The generated image as a blob. + The generated image as a URL or base64 string. 
""" @@ -92,9 +92,7 @@ def __init__( api_url: str, disable_request_logging: Union[bool, None] = False, ): - super().__init__( - api_key, api_url, disable_request_logging=disable_request_logging - ) + super().__init__(api_key, api_url, disable_request_logging=disable_request_logging) self.config = RequestConfig( api_url=api_url, api_key=api_key, @@ -103,7 +101,7 @@ def __init__( def image_generation( self, params: ImageGenerationParams - ) -> ImageGenerationResponse: + ) -> Union[ImageGenerationResponse, bytes]: path = "/ai/image_generation" resp = Request( config=self.config, @@ -123,9 +121,7 @@ def __init__( api_url: str, disable_request_logging: Union[bool, None] = False, ): - super().__init__( - api_key, api_url, disable_request_logging=disable_request_logging - ) + super().__init__(api_key, api_url, disable_request_logging=disable_request_logging) self.config = RequestConfig( api_url=api_url, api_key=api_key, @@ -134,7 +130,7 @@ def __init__( async def image_generation( self, params: ImageGenerationParams - ) -> ImageGenerationResponse: + ) -> Union[ImageGenerationResponse, bytes]: path = "/ai/image_generation" resp = await AsyncRequest( config=self.config, diff --git a/jigsawstack/prediction.py b/jigsawstack/prediction.py index d24168b..ec571a4 100644 --- a/jigsawstack/prediction.py +++ b/jigsawstack/prediction.py @@ -1,15 +1,15 @@ from typing import Any, Dict, List, Union, cast -from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union +from typing_extensions import TypedDict + from ._config import ClientConfig from ._types import BaseResponse +from .async_request import AsyncRequest +from .request import Request, RequestConfig class Dataset(TypedDict): - value: Union[int, str] + value: Union[int, float, str] """ The value of the dataset. 
""" @@ -27,11 +27,15 @@ class PredictionParams(TypedDict): """ steps: int """ - The number of predictions to make. The defualt is 5. + The number of predictions to make. The default is 5. """ class PredictionResponse(BaseResponse): + steps: int + """ + The number of steps predicted. + """ prediction: List[Dataset] """ The predictions made on the dataset. diff --git a/jigsawstack/prompt_engine.py b/jigsawstack/prompt_engine.py index 378e9b3..3af7fa3 100644 --- a/jigsawstack/prompt_engine.py +++ b/jigsawstack/prompt_engine.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Union, cast, Generator, Literal +from typing import Any, Dict, Generator, List, Literal, Union, cast + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union + from ._config import ClientConfig +from .async_request import AsyncRequest from .helpers import build_path +from .request import Request, RequestConfig class PromptEngineResult(TypedDict): @@ -118,14 +119,10 @@ def create(self, params: PromptEngineCreateParams) -> PromptEngineCreateResponse def get(self, id: str) -> PromptEngineGetResponse: path = f"/prompt_engine/{id}" - resp = Request( - config=self.config, path=path, params={}, verb="get" - ).perform_with_content() + resp = Request(config=self.config, path=path, params={}, verb="get").perform_with_content() return resp - def list( - self, params: Union[PromptEngineListParams, None] = None - ) -> PromptEngineListResponse: + def list(self, params: Union[PromptEngineListParams, None] = None) -> PromptEngineListResponse: if params is None: params = {} @@ -140,9 +137,7 @@ def list( base_path="/prompt_engine", params=params, ) - resp = Request( - config=self.config, path=path, params={}, verb="get" - ).perform_with_content() + resp = Request(config=self.config, path=path, params={}, verb="get").perform_with_content() return resp def delete(self, id: str) -> 
PromptEngineDeleteResponse: @@ -218,9 +213,7 @@ def __init__( disable_request_logging=disable_request_logging, ) - async def create( - self, params: PromptEngineCreateParams - ) -> PromptEngineCreateResponse: + async def create(self, params: PromptEngineCreateParams) -> PromptEngineCreateResponse: path = "/prompt_engine" resp = await AsyncRequest( config=self.config, diff --git a/jigsawstack/request.py b/jigsawstack/request.py index 68ac675..c1967a4 100644 --- a/jigsawstack/request.py +++ b/jigsawstack/request.py @@ -1,8 +1,10 @@ -from typing import Any, Dict, Generic, List, Union, cast, TypedDict, Generator +import json +from typing import Any, Dict, Generator, Generic, List, TypedDict, Union, cast + import requests from typing_extensions import Literal, TypeVar + from .exceptions import NoContentError, raise_for_code_and_type -import json RequestVerb = Literal["get", "post", "put", "patch", "delete"] @@ -23,7 +25,7 @@ def __init__( path: str, params: Union[Dict[Any, Any], List[Dict[Any, Any]]], verb: RequestVerb, - headers: Dict[str, str] = {"Content-Type": "application/json"}, + headers: Dict[str, str] = None, data: Union[bytes, None] = None, stream: Union[bool, None] = False, ): @@ -33,7 +35,7 @@ def __init__( self.api_url = config.get("api_url") self.api_key = config.get("api_key") self.data = data - self.headers = headers + self.headers = headers or {"Content-Type": "application/json"} self.disable_request_logging = config.get("disable_request_logging") self.stream = stream @@ -89,10 +91,7 @@ def perform_file(self) -> Union[T, None]: # handle error in case there is a statusCode attr present # and status != 200 and response is a json. - if ( - "application/json" not in resp.headers["content-type"] - and resp.status_code != 200 - ): + if "application/json" not in resp.headers["content-type"] and resp.status_code != 200: raise_for_code_and_type( code=500, message="Failed to parse JigsawStack API response. 
Please try again.", diff --git a/jigsawstack/search.py b/jigsawstack/search.py index 4b10884..21b0187 100644 --- a/jigsawstack/search.py +++ b/jigsawstack/search.py @@ -1,9 +1,11 @@ -from typing import Any, Dict, List, Union, cast, Literal -from typing_extensions import NotRequired, TypedDict, Optional -from .request import Request, RequestConfig -from .async_request import AsyncRequest, AsyncRequestConfig +from typing import Any, Dict, List, Literal, Optional, Union, cast + +from typing_extensions import NotRequired, TypedDict + from ._config import ClientConfig from ._types import BaseResponse +from .async_request import AsyncRequest, AsyncRequestConfig +from .request import Request, RequestConfig class RelatedIndex(TypedDict): @@ -247,7 +249,7 @@ def search(self, params: SearchParams) -> SearchResponse: "spell_check": spell_check, } - path = f"/web/search" + path = "/web/search" resp = Request( config=self.config, path=path, @@ -269,7 +271,7 @@ def suggestions(self, params: SearchSuggestionsParams) -> SearchSuggestionsRespo return resp def deep_research(self, params: DeepResearchParams) -> DeepResearchResponse: - path = f"/web/deep_research" + path = "/web/deep_research" resp = Request( config=self.config, path=path, @@ -296,7 +298,7 @@ def __init__( ) async def search(self, params: SearchParams) -> SearchResponse: - path = f"/web/search" + path = "/web/search" query = params["query"] ai_overview = params.get("ai_overview", "True") safe_search = params.get("safe_search", "moderate") @@ -317,9 +319,7 @@ async def search(self, params: SearchParams) -> SearchResponse: ).perform_with_content() return resp - async def suggestions( - self, params: SearchSuggestionsParams - ) -> SearchSuggestionsResponse: + async def suggestions(self, params: SearchSuggestionsParams) -> SearchSuggestionsResponse: query = params["query"] path = f"/web/search/suggest?query={query}" resp = await AsyncRequest( @@ -331,7 +331,7 @@ async def suggestions( return resp async def 
deep_research(self, params: DeepResearchParams) -> DeepResearchResponse: - path = f"/web/deep_research" + path = "/web/deep_research" resp = await AsyncRequest( config=self.config, path=path, diff --git a/jigsawstack/sentiment.py b/jigsawstack/sentiment.py index 8970110..ef5e9df 100644 --- a/jigsawstack/sentiment.py +++ b/jigsawstack/sentiment.py @@ -1,10 +1,11 @@ from typing import Any, Dict, List, Union, cast -from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union + +from typing_extensions import TypedDict + from ._config import ClientConfig from ._types import BaseResponse +from .async_request import AsyncRequest +from .request import Request, RequestConfig class SentimentParams(TypedDict): diff --git a/jigsawstack/sql.py b/jigsawstack/sql.py index d2dfc3b..b895485 100644 --- a/jigsawstack/sql.py +++ b/jigsawstack/sql.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Union, cast, Literal +from typing import Any, Dict, Literal, Union, cast + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union + from ._config import ClientConfig from ._types import BaseResponse +from .async_request import AsyncRequest +from .request import Request, RequestConfig class SQLParams(TypedDict): diff --git a/jigsawstack/store.py b/jigsawstack/store.py index 72bf191..0693f49 100644 --- a/jigsawstack/store.py +++ b/jigsawstack/store.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Union, cast +from typing import Any, Union + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest, AsyncRequestConfig + from ._config import ClientConfig +from .async_request import AsyncRequest, AsyncRequestConfig from .helpers import build_path -from .exceptions import 
JigsawStackError +from .request import Request, RequestConfig class FileDeleteResponse(TypedDict): @@ -22,9 +23,7 @@ class FileUploadResponse(TypedDict): key: str url: str size: int - temp_public_url: NotRequired[ - str - ] # Optional, only if temp_public_url is set to True in params + temp_public_url: NotRequired[str] # Optional, only if temp_public_url is set to True in params class Store(ClientConfig): diff --git a/jigsawstack/summary.py b/jigsawstack/summary.py index 898c42b..0d19b39 100644 --- a/jigsawstack/summary.py +++ b/jigsawstack/summary.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Union, cast, Literal +from typing import Any, Dict, List, Literal, Union, cast + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union + from ._config import ClientConfig from ._types import BaseResponse +from .async_request import AsyncRequest +from .request import Request, RequestConfig class SummaryParams(TypedDict): diff --git a/jigsawstack/translate.py b/jigsawstack/translate.py index 14d225a..63b7fa5 100644 --- a/jigsawstack/translate.py +++ b/jigsawstack/translate.py @@ -1,11 +1,12 @@ from typing import Any, Dict, List, Union, cast, overload -from typing_extensions import NotRequired, TypedDict, Literal -from .request import Request, RequestConfig -from .async_request import AsyncRequest -from typing import List, Union + +from typing_extensions import Literal, NotRequired, TypedDict + from ._config import ClientConfig -from .helpers import build_path from ._types import BaseResponse +from .async_request import AsyncRequest +from .helpers import build_path +from .request import Request, RequestConfig class TranslateImageParams(TypedDict): @@ -50,10 +51,10 @@ class TranslateResponse(BaseResponse): """ -class TranslateImageResponse(TypedDict): - image: bytes +class TranslateImageResponse(BaseResponse): + url: str """ - The image data that was 
translated. + The URL or base64 of the translated image. """ @@ -83,17 +84,17 @@ def text(self, params: TranslateParams) -> TranslateResponse: return resp @overload - def image(self, params: TranslateImageParams) -> TranslateImageResponse: ... + def image(self, params: TranslateImageParams) -> Union[TranslateImageResponse, bytes]: ... @overload def image( self, blob: bytes, options: TranslateImageParams = None - ) -> TranslateImageParams: ... + ) -> Union[TranslateImageResponse, bytes]: ... def image( self, blob: Union[TranslateImageParams, bytes], options: TranslateImageParams = None, - ) -> TranslateImageResponse: + ) -> Union[TranslateImageResponse, bytes]: if isinstance( blob, dict ): # If params is provided as a dict, we assume it's the first argument @@ -147,17 +148,17 @@ async def text(self, params: TranslateParams) -> TranslateResponse: return resp @overload - async def image(self, params: TranslateImageParams) -> TranslateImageResponse: ... + async def image(self, params: TranslateImageParams) -> Union[TranslateImageResponse, bytes]: ... @overload async def image( self, blob: bytes, options: TranslateImageParams = None - ) -> TranslateImageParams: ... + ) -> Union[TranslateImageResponse, bytes]: ... 
async def image( self, blob: Union[TranslateImageParams, bytes], options: TranslateImageParams = None, - ) -> TranslateImageResponse: + ) -> Union[TranslateImageResponse, bytes]: if isinstance(blob, dict): resp = await AsyncRequest( config=self.config, diff --git a/jigsawstack/validate.py b/jigsawstack/validate.py index 1d4f715..fc57c3c 100644 --- a/jigsawstack/validate.py +++ b/jigsawstack/validate.py @@ -1,12 +1,12 @@ from typing import Any, Dict, List, Union, cast, overload + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest, AsyncRequestConfig + from ._config import ClientConfig -from typing import Any, Dict, List, cast -from typing_extensions import NotRequired, TypedDict, Union, Optional -from .helpers import build_path from ._types import BaseResponse +from .async_request import AsyncRequest, AsyncRequestConfig +from .helpers import build_path +from .request import Request, RequestConfig class Spam(TypedDict): diff --git a/jigsawstack/vision.py b/jigsawstack/vision.py index 4bb6ff5..6df4e37 100644 --- a/jigsawstack/vision.py +++ b/jigsawstack/vision.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Union, cast, Optional, overload -from typing_extensions import NotRequired, TypedDict, Literal -from .request import Request, RequestConfig -from .async_request import AsyncRequest, AsyncRequestConfig +from typing import Any, Dict, List, Optional, Union, cast, overload + +from typing_extensions import Literal, NotRequired, TypedDict + from ._config import ClientConfig -from .helpers import build_path from ._types import BaseResponse +from .async_request import AsyncRequest, AsyncRequestConfig +from .request import Request, RequestConfig class Point(TypedDict): @@ -190,6 +191,8 @@ def vocr( blob: Union[VOCRParams, bytes], options: VOCRParams = None, ) -> OCRResponse: + path = "/vocr" + options = options or {} if isinstance( blob, dict ): # If params is provided as a 
dict, we assume it's the first argument @@ -201,8 +204,6 @@ def vocr( ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/vocr", params=options) content_type = options.get("content_type", "application/octet-stream") headers = {"Content-Type": content_type} @@ -217,9 +218,7 @@ def vocr( return resp @overload - def object_detection( - self, params: ObjectDetectionParams - ) -> ObjectDetectionResponse: ... + def object_detection(self, params: ObjectDetectionParams) -> ObjectDetectionResponse: ... @overload def object_detection( self, blob: bytes, options: ObjectDetectionParams = None @@ -230,17 +229,17 @@ def object_detection( blob: Union[ObjectDetectionParams, bytes], options: ObjectDetectionParams = None, ) -> ObjectDetectionResponse: + path = "/object_detection" + options = options or {} if isinstance(blob, dict): resp = Request( config=self.config, - path="/object_detection", + path=path, params=cast(Dict[Any, Any], blob), verb="post", ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/object_detection", params=options) content_type = options.get("content_type", "application/octet-stream") headers = {"Content-Type": content_type} @@ -281,17 +280,17 @@ async def vocr( blob: Union[VOCRParams, bytes], options: VOCRParams = None, ) -> OCRResponse: + path = "/vocr" + options = options or {} if isinstance(blob, dict): resp = await AsyncRequest( config=self.config, - path="/vocr", + path=path, params=cast(Dict[Any, Any], blob), verb="post", ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/vocr", params=options) content_type = options.get("content_type", "application/octet-stream") headers = {"Content-Type": content_type} @@ -306,9 +305,7 @@ async def vocr( return resp @overload - async def object_detection( - self, params: ObjectDetectionParams - ) -> ObjectDetectionResponse: ... 
+ async def object_detection(self, params: ObjectDetectionParams) -> ObjectDetectionResponse: ... @overload async def object_detection( self, blob: bytes, options: ObjectDetectionParams = None @@ -319,19 +316,19 @@ async def object_detection( blob: Union[ObjectDetectionParams, bytes], options: ObjectDetectionParams = None, ) -> ObjectDetectionResponse: + path = "/object_detection" + options = options or {} if isinstance( blob, dict ): # If params is provided as a dict, we assume it's the first argument resp = await AsyncRequest( config=self.config, - path="/object_detection", + path=path, params=cast(Dict[Any, Any], blob), verb="post", ).perform_with_content() return resp - options = options or {} - path = build_path(base_path="/object_detection", params=options) content_type = options.get("content_type", "application/octet-stream") headers = {"Content-Type": content_type} diff --git a/jigsawstack/web.py b/jigsawstack/web.py index 58d9307..5d400c3 100644 --- a/jigsawstack/web.py +++ b/jigsawstack/web.py @@ -1,27 +1,26 @@ -from typing import Any, Dict, List, Union, Optional, cast, Literal, overload +from typing import Any, Dict, List, Literal, Optional, Union, cast, overload + from typing_extensions import NotRequired, TypedDict -from .request import Request, RequestConfig -from .async_request import AsyncRequest, AsyncRequestConfig from ._config import ClientConfig +from ._types import BaseResponse +from .async_request import AsyncRequest, AsyncRequestConfig +from .request import Request, RequestConfig from .search import ( + AsyncSearch, + DeepResearchParams, + DeepResearchResponse, Search, SearchParams, + SearchResponse, SearchSuggestionsParams, SearchSuggestionsResponse, - SearchResponse, - AsyncSearch, - DeepResearchParams, - DeepResearchResponse, ) -from ._types import BaseResponse class GotoOptions(TypedDict): timeout: NotRequired[int] - wait_until: NotRequired[ - Literal["load", "domcontentloaded", "networkidle0", "networkidle2"] - ] + wait_until: 
NotRequired[Literal["load", "domcontentloaded", "networkidle0", "networkidle2"]] # @@ -256,9 +255,7 @@ def search(self, params: SearchParams) -> SearchResponse: ) return s.search(params) - def search_suggestions( - self, params: SearchSuggestionsParams - ) -> SearchSuggestionsResponse: + def search_suggestions(self, params: SearchSuggestionsParams) -> SearchSuggestionsResponse: s = Search( self.api_key, self.api_url, @@ -308,9 +305,7 @@ async def ai_scrape(self, params: AIScrapeParams) -> AIScrapeResponse: async def html_to_any(self, params: HTMLToAnyURLParams) -> HTMLToAnyURLResponse: ... @overload - async def html_to_any( - self, params: HTMLToAnyBinaryParams - ) -> HTMLToAnyBinaryResponse: ... + async def html_to_any(self, params: HTMLToAnyBinaryParams) -> HTMLToAnyBinaryResponse: ... async def html_to_any( self, params: HTMLToAnyParams diff --git a/requirements.txt b/requirements.txt index 0a1a976..351d200 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ requests>=2.31.0 typing_extensions -aiohttp \ No newline at end of file +aiohttp>=3.12.15 \ No newline at end of file diff --git a/tests/test_ai_scrape.py b/tests/test_ai_scrape.py new file mode 100644 index 0000000..4c30b33 --- /dev/null +++ b/tests/test_ai_scrape.py @@ -0,0 +1,141 @@ +import logging +import os + +import pytest +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +URL = "https://jigsawstack.com" + +# AI Scrape Test Cases +AI_SCRAPE_TEST_CASES = [ + { + "name": "scrape_with_element_prompts", + "params": { + "url": URL, + "element_prompts": ["title", "main content", "navigation links"], + }, + }, + { + "name": "scrape_with_selectors", + "params": { + "url": URL, 
+ "selectors": ["h1", "p", "a"], + }, + }, + { + "name": "scrape_with_features", + "params": { + "url": URL, + "element_prompts": ["title"], + "features": ["meta", "link"], + }, + }, + { + "name": "scrape_with_root_element", + "params": { + "url": URL, + "element_prompts": ["content"], + "root_element_selector": "main", + }, + }, + { + "name": "scrape_with_wait_for_timeout", + "params": { + "url": URL, + "element_prompts": ["content"], + "wait_for": {"mode": "timeout", "value": 3000}, + }, + }, + { + "name": "scrape_mobile_view", + "params": { + "url": URL, + "element_prompts": ["mobile menu"], + "is_mobile": True, + }, + }, + { + "name": "scrape_with_cookies", + "params": { + "url": URL, + "element_prompts": ["user data"], + "cookies": [{"name": "session", "value": "test123", "domain": "example.com"}], + }, + }, + { + "name": "scrape_with_advance_config", + "params": { + "url": URL, + "element_prompts": ["content"], + "advance_config": {"console": True, "network": True, "cookies": True}, + }, + }, +] + + +class TestAIScrapeSync: + """Test synchronous AI scrape methods""" + + @pytest.mark.parametrize( + "test_case", + AI_SCRAPE_TEST_CASES, + ids=[tc["name"] for tc in AI_SCRAPE_TEST_CASES], + ) + def test_ai_scrape(self, test_case): + """Test synchronous AI scrape with various inputs""" + try: + result = jigsaw.web.ai_scrape(test_case["params"]) + + assert result["success"] + assert "data" in result + assert isinstance(result["data"], list) + + # Check for optional features + if "meta" in test_case["params"].get("features", []): + assert "meta" in result + if "link" in test_case["params"].get("features", []): + assert "link" in result + assert isinstance(result["link"], list) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestAIScrapeAsync: + """Test asynchronous AI scrape methods""" + + @pytest.mark.parametrize( + "test_case", + AI_SCRAPE_TEST_CASES, + ids=[tc["name"] for tc in 
AI_SCRAPE_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_ai_scrape_async(self, test_case): + """Test asynchronous AI scrape with various inputs""" + try: + result = await async_jigsaw.web.ai_scrape(test_case["params"]) + + assert result["success"] + assert "data" in result + assert isinstance(result["data"], list) + + # Check for optional features + if "meta" in test_case["params"].get("features", []): + assert "meta" in result + if "link" in test_case["params"].get("features", []): + assert "link" in result + assert isinstance(result["link"], list) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_async_web.py b/tests/test_async_web.py deleted file mode 100644 index 99899a8..0000000 --- a/tests/test_async_web.py +++ /dev/null @@ -1,36 +0,0 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import AsyncJigsawStack -import pytest -import asyncio -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_async_web_search_response(): - async def _test(): - client = AsyncJigsawStack() - try: - result = await client.web.search({"query": "JigsawStack fund raising"}) - # logger.info(result) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) - - -def test_async_web_search_suggestion_response(): - async def _test(): - client = AsyncJigsawStack() - try: - result = await client.web.search_suggestion({"query": "Lagos"}) - logger.info(result) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) diff --git a/tests/test_audio.py b/tests/test_audio.py new file mode 100644 index 0000000..037f285 --- /dev/null +++ b/tests/test_audio.py @@ -0,0 +1,220 @@ +import logging +import os + 
+import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +# Sample audio URLs for testing +AUDIO_URL = AUDIO_URL_LONG = "https://jigsawstack.com/preview/stt-example.wav" + + +TEST_CASES = [ + { + "name": "with_url_only", + "params": {"url": AUDIO_URL}, + "blob": None, + "options": None, + }, + { + "name": "with_url_and_language", + "params": {"url": AUDIO_URL, "language": "en"}, + "blob": None, + "options": None, + }, + { + "name": "with_url_auto_detect_language", + "params": {"url": AUDIO_URL, "language": "auto"}, + "blob": None, + "options": None, + }, + { + "name": "with_url_and_translate", + "params": {"url": AUDIO_URL, "translate": True}, + "blob": None, + "options": None, + }, + { + "name": "with_blob_only", + "params": None, + "blob": AUDIO_URL, + "options": None, + }, + { + "name": "with_blob_and_language", + "params": None, + "blob": AUDIO_URL, + "options": {"language": "en"}, + }, + { + "name": "with_blob_auto_detect", + "params": None, + "blob": AUDIO_URL, + "options": {"language": "auto"}, + }, + { + "name": "with_blob_and_translate", + "params": None, + "blob": AUDIO_URL, + "options": {"translate": True, "language": "en"}, + }, + { + "name": "with_by_speaker", + "params": {"url": AUDIO_URL_LONG, "by_speaker": True}, + "blob": None, + "options": None, + }, + { + "name": "with_chunk_settings", + "params": {"url": AUDIO_URL, "batch_size": 5, "chunk_duration": 15}, + "blob": None, + "options": None, + }, + { + "name": "with_all_options", + "params": None, + "blob": AUDIO_URL_LONG, + "options": { + "language": "auto", + "translate": False, + "by_speaker": True, + "batch_size": 10, + "chunk_duration": 15, + }, + }, +] + 
+# Test cases with webhook (separate as they return different response) +WEBHOOK_TEST_CASES = [ + { + "name": "with_webhook_url", + "params": {"url": AUDIO_URL, "webhook_url": "https://webhook.site/test-webhook"}, + "blob": None, + "options": None, + }, + { + "name": "with_blob_and_webhook", + "params": None, + "blob": AUDIO_URL, + "options": {"webhook_url": "https://webhook.site/test-webhook", "language": "en"}, + }, +] + + +class TestAudioSync: + """Test synchronous audio speech-to-text methods""" + + @pytest.mark.parametrize("test_case", TEST_CASES, ids=[tc["name"] for tc in TEST_CASES]) + def test_speech_to_text(self, test_case): + """Test synchronous speech-to-text with various inputs""" + try: + if test_case.get("blob"): + # Download audio content + blob_content = requests.get(test_case["blob"]).content + result = jigsaw.audio.speech_to_text(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = jigsaw.audio.speech_to_text(test_case["params"]) + # Verify response structure + assert result["success"] + assert result.get("text", None) is not None and isinstance(result["text"], str) + + # Check for chunks + if result.get("chunks", None): + assert isinstance(result["chunks"], list) + + # Check for speaker diarization if requested + if result.get("speakers", None): + assert isinstance(result["speakers"], list) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", WEBHOOK_TEST_CASES, ids=[tc["name"] for tc in WEBHOOK_TEST_CASES] + ) + def test_speech_to_text_webhook(self, test_case): + """Test synchronous speech-to-text with webhook""" + try: + if test_case.get("blob"): + # Download audio content + blob_content = requests.get(test_case["blob"]).content + result = jigsaw.audio.speech_to_text(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = jigsaw.audio.speech_to_text(test_case["params"]) + # Verify 
webhook response structure + assert result["success"] + + except JigsawStackError as e: + # Webhook URLs might fail if invalid + print(f"Expected possible error for webhook test {test_case['name']}: {e}") + + +class TestAudioAsync: + """Test asynchronous audio speech-to-text methods""" + + @pytest.mark.parametrize("test_case", TEST_CASES, ids=[tc["name"] for tc in TEST_CASES]) + @pytest.mark.asyncio + async def test_speech_to_text_async(self, test_case): + """Test asynchronous speech-to-text with various inputs""" + try: + if test_case.get("blob"): + # Download audio content + blob_content = requests.get(test_case["blob"]).content + result = await async_jigsaw.audio.speech_to_text( + blob_content, test_case.get("options", {}) + ) + else: + # Use params directly + result = await async_jigsaw.audio.speech_to_text(test_case["params"]) + + # Verify response structure + assert result["success"] + assert result.get("text", None) is not None and isinstance(result["text"], str) + + # Check for chunks + if result.get("chunks", None): + assert isinstance(result["chunks"], list) + + # Check for speaker diarization if requested + if result.get("speakers", None): + assert isinstance(result["speakers"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in async {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", WEBHOOK_TEST_CASES, ids=[tc["name"] for tc in WEBHOOK_TEST_CASES] + ) + @pytest.mark.asyncio + async def test_speech_to_text_webhook_async(self, test_case): + """Test asynchronous speech-to-text with webhook""" + try: + if test_case.get("blob"): + # Download audio content + blob_content = requests.get(test_case["blob"]).content + result = await async_jigsaw.audio.speech_to_text( + blob_content, test_case.get("options", {}) + ) + else: + # Use params directly + result = await async_jigsaw.audio.speech_to_text(test_case["params"]) + + print(f"Async test {test_case['name']}: Webhook response") + + # Verify webhook response 
structure + assert result["success"] + + except JigsawStackError as e: + # Webhook URLs might fail if invalid + print(f"Expected possible error for async webhook test {test_case['name']}: {e}") diff --git a/tests/test_classification.py b/tests/test_classification.py index 6c301c5..dba924a 100644 --- a/tests/test_classification.py +++ b/tests/test_classification.py @@ -1,75 +1,120 @@ -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import JigsawStack +import logging +import os import pytest +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError -# flake8: noqa +load_dotenv() -client = JigsawStack() +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) -@pytest.mark.parametrize( - "dataset,labels", - [ - ( - [ +TEST_CASES = [ + { + "name": "text_classification_programming", + "params": { + "dataset": [ {"type": "text", "value": "I love programming"}, {"type": "text", "value": "I love reading books"}, {"type": "text", "value": "I love watching movies"}, {"type": "text", "value": "I love playing games"}, ], - [ + "labels": [ {"type": "text", "value": "programming"}, {"type": "text", "value": "reading"}, {"type": "text", "value": "watching"}, {"type": "text", "value": "playing"}, ], - ), - ( - [ + }, + }, + { + "name": "text_classification_sentiment", + "params": { + "dataset": [ {"type": "text", "value": "This is awesome!"}, {"type": "text", "value": "I hate this product"}, {"type": "text", "value": "It's okay, nothing special"}, ], - [ + "labels": [ {"type": "text", "value": "positive"}, {"type": "text", "value": "negative"}, {"type": "text", "value": "neutral"}, ], - ), - ( - [ + }, + }, + { + "name": "text_classification_weather", + "params": { + "dataset": [ {"type": "text", "value": "The weather is sunny 
today"}, {"type": "text", "value": "It's raining heavily outside"}, {"type": "text", "value": "Snow is falling gently"}, ], - [ + "labels": [ {"type": "text", "value": "sunny"}, {"type": "text", "value": "rainy"}, {"type": "text", "value": "snowy"}, ], - ), - ], -) -def test_classification_text_success_response(dataset, labels) -> None: - params = { - "dataset": dataset, - "labels": labels, - } - try: - result = client.classification.text(params) - print(result) - assert result["success"] == True - except JigsawStackError as e: - print(str(e)) - assert e.message == "Failed to parse API response. Please try again." - - -@pytest.mark.parametrize( - "dataset,labels", - [ - ( - [ + }, + }, + { + "name": "image_classification_fruits", + "params": { + "dataset": [ + { + "type": "image", + "value": "https://as2.ftcdn.net/v2/jpg/02/24/11/57/1000_F_224115780_2ssvcCoTfQrx68Qsl5NxtVIDFWKtAgq2.jpg", + }, + { + "type": "image", + "value": "https://t3.ftcdn.net/jpg/02/95/44/22/240_F_295442295_OXsXOmLmqBUfZreTnGo9PREuAPSLQhff.jpg", + }, + { + "type": "image", + "value": "https://as1.ftcdn.net/v2/jpg/05/54/94/46/1000_F_554944613_okdr3fBwcE9kTOgbLp4BrtVi8zcKFWdP.jpg", + }, + ], + "labels": [ + {"type": "text", "value": "banana"}, + { + "type": "image", + "value": "https://upload.wikimedia.org/wikipedia/commons/8/8a/Banana-Single.jpg", + }, + {"type": "text", "value": "kisses"}, + ], + }, + }, + { + "name": "text_classification_multiple_labels", + "params": { + "dataset": [ + { + "type": "text", + "value": "Python is a great programming language for data science", + }, + { + "type": "text", + "value": "JavaScript is essential for web development", + }, + ], + "labels": [ + {"type": "text", "value": "programming"}, + {"type": "text", "value": "data science"}, + {"type": "text", "value": "web development"}, + ], + "multiple_labels": True, + }, + }, + { + "name": "image_classification_with_multiple_labels", + "params": { + "dataset": [ { "type": "image", "value": 
"https://as2.ftcdn.net/v2/jpg/02/24/11/57/1000_F_224115780_2ssvcCoTfQrx68Qsl5NxtVIDFWKtAgq2.jpg", @@ -83,7 +128,7 @@ def test_classification_text_success_response(dataset, labels) -> None: "value": "https://as1.ftcdn.net/v2/jpg/05/54/94/46/1000_F_554944613_okdr3fBwcE9kTOgbLp4BrtVi8zcKFWdP.jpg", }, ], - [ + "labels": [ {"type": "text", "value": "banana"}, { "type": "image", @@ -91,18 +136,53 @@ def test_classification_text_success_response(dataset, labels) -> None: }, {"type": "text", "value": "kisses"}, ], - ), - ], -) -def test_classification_image_success_response(dataset, labels) -> None: - params = { - "dataset": dataset, - "labels": labels, - } - try: - result = client.classification.image(params) - print(result) - assert result["success"] == True - except JigsawStackError as e: - print(str(e)) - assert e.message == "Failed to parse API response. Please try again." + }, + }, +] + + +class TestClassificationSync: + """Test synchronous classification methods""" + + sync_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_classification(self, test_case): + """Test synchronous classification with various inputs""" + try: + result = jigsaw.classification(test_case["params"]) + assert result["success"] + assert "predictions" in result + if test_case.get("multiple_labels"): + # Ensure predictions are lists when multiple_labels is True + for prediction in result["predictions"]: + assert isinstance(prediction, list) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestClassificationAsync: + """Test asynchronous classification methods""" + + async_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_classification_async(self, test_case): + """Test asynchronous classification with various 
inputs""" + try: + result = await async_jigsaw.classification(test_case["params"]) + assert result["success"] + assert "predictions" in result + + if test_case.get("multiple_labels"): + # Ensure predictions are lists when multiple_labels is True + for prediction in result["predictions"]: + assert isinstance(prediction, list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_deep_research.py b/tests/test_deep_research.py new file mode 100644 index 0000000..3d584ab --- /dev/null +++ b/tests/test_deep_research.py @@ -0,0 +1,95 @@ +import logging +import os + +import pytest +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +URL = "https://jigsawstack.com" + + +# Deep Research Test Cases +DEEP_RESEARCH_TEST_CASES = [ + { + "name": "basic_deep_research", + "params": { + "query": "climate change effects", + }, + }, + { + "name": "technical_deep_research", + "params": { + "query": "quantum computing applications in cryptography", + }, + }, + { + "name": "deep_research_with_depth", + "params": { + "query": "renewable energy sources", + "depth": 2, + }, + }, +] + + +class TestDeepResearchSync: + """Test synchronous deep research methods""" + + @pytest.mark.parametrize( + "test_case", + DEEP_RESEARCH_TEST_CASES, + ids=[tc["name"] for tc in DEEP_RESEARCH_TEST_CASES], + ) + def test_deep_research(self, test_case): + """Test synchronous deep research with various inputs""" + try: + result = jigsaw.web.deep_research(test_case["params"]) + + assert result["success"] + assert "results" in result + assert isinstance(result["results"], str) + assert len(result["results"]) > 0 + + # Check 
for sources + if "sources" in result: + assert isinstance(result["sources"], list) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestDeepResearchAsync: + """Test asynchronous deep research methods""" + + @pytest.mark.parametrize( + "test_case", + DEEP_RESEARCH_TEST_CASES, + ids=[tc["name"] for tc in DEEP_RESEARCH_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_deep_research_async(self, test_case): + """Test asynchronous deep research with various inputs""" + try: + result = await async_jigsaw.web.deep_research(test_case["params"]) + + assert result["success"] + assert "results" in result + assert isinstance(result["results"], str) + assert len(result["results"]) > 0 + + # Check for sources + if "sources" in result: + assert isinstance(result["sources"], list) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_embedding.py b/tests/test_embedding.py new file mode 100644 index 0000000..7b6b368 --- /dev/null +++ b/tests/test_embedding.py @@ -0,0 +1,325 @@ +import logging +import os + +import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +SAMPLE_TEXT = ( + "The quick brown fox jumps over the lazy dog. This is a sample text for embedding generation." 
+) +SAMPLE_IMAGE_URL = "https://images.unsplash.com/photo-1542931287-023b922fa89b?q=80&w=2574&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" +SAMPLE_AUDIO_URL = "https://jigsawstack.com/preview/stt-example.wav" +SAMPLE_PDF_URL = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" + +# Test cases for Embedding V1 +EMBEDDING_V1_TEST_CASES = [ + { + "name": "text_embedding_basic", + "params": { + "type": "text", + "text": SAMPLE_TEXT, + }, + }, + { + "name": "text_embedding_with_truncate", + "params": { + "type": "text", + "text": SAMPLE_TEXT * 100, # Long text to test truncation + "token_overflow_mode": "truncate", + }, + }, + { + "name": "text_embedding_with_error_mode", + "params": { + "type": "text", + "text": SAMPLE_TEXT, + "token_overflow_mode": "error", + }, + }, + { + "name": "image_embedding_from_url", + "params": { + "type": "image", + "url": SAMPLE_IMAGE_URL, + }, + }, + { + "name": "audio_embedding_from_url", + "params": { + "type": "audio", + "url": SAMPLE_AUDIO_URL, + }, + }, + { + "name": "pdf_embedding_from_url", + "params": { + "type": "pdf", + "url": SAMPLE_PDF_URL, + }, + }, + { + "name": "text_other_type", + "params": { + "type": "text-other", + "text": "This is a different text type for embedding", + }, + }, +] + +# Test cases for Embedding V2 +EMBEDDING_V2_TEST_CASES = [ + { + "name": "text_embedding_v2_basic", + "params": { + "type": "text", + "text": SAMPLE_TEXT, + }, + }, + { + "name": "text_embedding_v2_with_error", + "params": { + "type": "text", + "text": SAMPLE_TEXT * 100, # Long text to test chunking + "token_overflow_mode": "error", + }, + }, + { + "name": "text_embedding_v2_with_truncate", + "params": { + "type": "text", + "text": SAMPLE_TEXT * 100, + "token_overflow_mode": "truncate", + }, + }, + { + "name": "text_embedding_v2_with_error_mode", + "params": { + "type": "text", + "text": SAMPLE_TEXT, + "token_overflow_mode": "error", + }, + }, + { + "name": 
"image_embedding_v2_from_url", + "params": { + "type": "image", + "url": SAMPLE_IMAGE_URL, + }, + }, + { + "name": "audio_embedding_v2_basic", + "params": { + "type": "audio", + "url": SAMPLE_AUDIO_URL, + }, + }, + { + "name": "audio_embedding_v2_with_speaker_fingerprint", + "params": { + "type": "audio", + "url": SAMPLE_AUDIO_URL, + "speaker_fingerprint": True, + }, + }, + { + "name": "pdf_embedding_v2_from_url", + "params": { + "type": "pdf", + "url": SAMPLE_PDF_URL, + }, + }, +] + +# Test cases for blob inputs +BLOB_TEST_CASES = [ + { + "name": "image_blob_embedding", + "blob_url": SAMPLE_IMAGE_URL, + "options": { + "type": "image", + }, + }, + { + "name": "pdf_blob_embedding", + "blob_url": SAMPLE_PDF_URL, + "options": { + "type": "pdf", + }, + }, +] + + +class TestEmbeddingV1Sync: + """Test synchronous Embedding V1 methods""" + + sync_test_cases = EMBEDDING_V1_TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_embedding_v1(self, test_case): + """Test synchronous embedding v1 with various inputs""" + try: + result = jigsaw.embedding(test_case["params"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + if "chunks" in result: + assert isinstance(result["chunks"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", BLOB_TEST_CASES, ids=[tc["name"] for tc in BLOB_TEST_CASES] + ) + def test_embedding_v1_blob(self, test_case): + """Test synchronous embedding v1 with blob inputs""" + try: + # Download blob content + blob_content = requests.get(test_case["blob_url"]).content + result = jigsaw.embedding(blob_content, test_case["options"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in 
{test_case['name']}: {e}") + + +class TestEmbeddingV1Async: + """Test asynchronous Embedding V1 methods""" + + async_test_cases = EMBEDDING_V1_TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_embedding_v1_async(self, test_case): + """Test asynchronous embedding v1 with various inputs""" + try: + result = await async_jigsaw.embedding(test_case["params"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + if "chunks" in result: + assert isinstance(result["chunks"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", BLOB_TEST_CASES, ids=[tc["name"] for tc in BLOB_TEST_CASES] + ) + @pytest.mark.asyncio + async def test_embedding_v1_blob_async(self, test_case): + """Test asynchronous embedding v1 with blob inputs""" + try: + # Download blob content + blob_content = requests.get(test_case["blob_url"]).content + result = await async_jigsaw.embedding(blob_content, test_case["options"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestEmbeddingV2Sync: + """Test synchronous Embedding V2 methods""" + + sync_test_cases = EMBEDDING_V2_TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_embedding_v2(self, test_case): + """Test synchronous embedding v2 with various inputs""" + try: + result = jigsaw.embeddingV2(test_case["params"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + + # Check for chunks when chunking mode is used + if test_case["params"].get("token_overflow_mode") == "error": + 
assert "chunks" in result + assert isinstance(result["chunks"], list) + + # Check for speaker embeddings when speaker fingerprint is requested + if test_case["params"].get("speaker_fingerprint"): + assert "speaker_embeddings" in result + assert isinstance(result["speaker_embeddings"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", BLOB_TEST_CASES, ids=[tc["name"] for tc in BLOB_TEST_CASES] + ) + def test_embedding_v2_blob(self, test_case): + """Test synchronous embedding v2 with blob inputs""" + try: + # Download blob content + blob_content = requests.get(test_case["blob_url"]).content + result = jigsaw.embeddingV2(blob_content, test_case["options"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestEmbeddingV2Async: + """Test asynchronous Embedding V2 methods""" + + async_test_cases = EMBEDDING_V2_TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_embedding_v2_async(self, test_case): + """Test asynchronous embedding v2 with various inputs""" + try: + result = await async_jigsaw.embeddingV2(test_case["params"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + + # Check for chunks when chunking mode is used + if test_case["params"].get("token_overflow_mode") == "error": + assert "chunks" in result + assert isinstance(result["chunks"], list) + + # Check for speaker embeddings when speaker fingerprint is requested + if test_case["params"].get("speaker_fingerprint"): + assert "speaker_embeddings" in result + assert isinstance(result["speaker_embeddings"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected 
JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", BLOB_TEST_CASES, ids=[tc["name"] for tc in BLOB_TEST_CASES] + ) + @pytest.mark.asyncio + async def test_embedding_v2_blob_async(self, test_case): + """Test asynchronous embedding v2 with blob inputs""" + try: + # Download blob content + blob_content = requests.get(test_case["blob_url"]).content + result = await async_jigsaw.embeddingV2(blob_content, test_case["options"]) + assert result["success"] + assert "embeddings" in result + assert isinstance(result["embeddings"], list) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py deleted file mode 100644 index bf2e1e6..0000000 --- a/tests/test_embedding_async.py +++ /dev/null @@ -1,23 +0,0 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import AsyncJigsawStack -import pytest -import asyncio -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_async_embedding_generation_response(): - async def _test(): - client = AsyncJigsawStack() - try: - result = await client.embedding({"text": "Hello, World!", "type": "text"}) - logger.info(result) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) diff --git a/tests/test_file_store.py b/tests/test_file_store.py index daef198..97d07dd 100644 --- a/tests/test_file_store.py +++ b/tests/test_file_store.py @@ -1,64 +1,157 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import JigsawStack +import logging +import os +import uuid import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError -# 
flake8: noqa +load_dotenv() +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -client = JigsawStack() +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +TEXT_FILE_CONTENT = b"This is a test file content for JigsawStack storage" +JSON_FILE_CONTENT = b'{"test": "data", "key": "value"}' +BINARY_FILE_CONTENT = requests.get( + "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg" +).content -@pytest.mark.skip(reason="Skipping TestStoreAPI class for now") -class TestStoreAPI(unittest.TestCase): - def test_upload_success_response(self) -> None: - # Sample file content as bytes - file_content = b"This is a test file content" - options = { - "key": "test-file.txt", +TEST_CASES_UPLOAD = [ + { + "name": "upload_text_file_with_key", + "file": TEXT_FILE_CONTENT, + "options": { + "key": "sample_file.txt", "content_type": "text/plain", "overwrite": True, + }, + }, + { + "name": "upload_image_with_temp_url", + "file": BINARY_FILE_CONTENT, + "options": { + "key": "test_image.jpg", + "content_type": "image/jpeg", + "overwrite": True, "temp_public_url": True, - } - try: - result = client.store.upload(file_content, options) - assert result["success"] == True - except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." 
+ }, + }, + { + "name": "upload_binary_file", + "file": BINARY_FILE_CONTENT, + "options": { + "overwrite": True, + }, + }, + { + "name": "upload_file_no_options", + "file": TEXT_FILE_CONTENT, + "options": None, + }, +] - def test_get_success_response(self) -> None: - key = "test-file.txt" + +class TestFileStoreSync: + """Test synchronous file store operations""" + + uploaded_keys = [] # Track uploaded files for cleanup + + @pytest.mark.parametrize( + "test_case", TEST_CASES_UPLOAD, ids=[tc["name"] for tc in TEST_CASES_UPLOAD] + ) + def test_file_upload(self, test_case): + """Test synchronous file upload with various options""" try: - result = client.store.get(key) - # For file retrieval, we expect the actual file content - assert result is not None + result = jigsaw.store.upload(test_case["file"], test_case["options"]) + + print(f"Upload test {test_case['name']}: {result}") + assert result.get("key") is not None + assert result.get("url") is not None + assert result.get("size") > 0 + + # Check temp_public_url if requested + if test_case.get("options") and test_case["options"].get("temp_public_url"): + assert result.get("temp_public_url") is not None + + # Store key for cleanup + self.uploaded_keys.append(result["key"]) + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." 
+ pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") - def test_delete_success_response(self) -> None: - key = "test-file.txt" + def test_file_get(self): + """Test synchronous file retrieval""" + # First upload a file to retrieve + test_key = f"test-get-{uuid.uuid4().hex[:8]}.txt" try: - result = client.store.delete(key) - assert result["success"] == True + upload_result = jigsaw.store.upload( + TEXT_FILE_CONTENT, {"key": test_key, "content_type": "text/plain"} + ) + + # Now retrieve it + file_content = jigsaw.store.get(upload_result["key"]) + assert file_content is not None + print(f"Retrieved file with key {upload_result['key']}") + + # Cleanup + self.uploaded_keys.append(upload_result["key"]) + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." + pytest.fail(f"Unexpected JigsawStackError in file get: {e}") + - def test_upload_without_options_success_response(self) -> None: - # Test upload without optional parameters - file_content = b"This is another test file content" +class TestFileStoreAsync: + """Test asynchronous file store operations""" + + uploaded_keys = [] # Track uploaded files for cleanup + + @pytest.mark.parametrize( + "test_case", TEST_CASES_UPLOAD, ids=[tc["name"] for tc in TEST_CASES_UPLOAD] + ) + @pytest.mark.asyncio + async def test_file_upload_async(self, test_case): + """Test asynchronous file upload with various options""" try: - result = client.store.upload(file_content) - assert result["success"] == True + result = await async_jigsaw.store.upload(test_case["file"], test_case["options"]) + + print(f"Async upload test {test_case['name']}: {result}") + assert result.get("key") is not None + assert result.get("url") is not None + assert result.get("size") > 0 + + # Check temp_public_url if requested + if test_case.get("options") and test_case["options"].get("temp_public_url"): + assert result.get("temp_public_url") is not None + + # Store key for cleanup + 
self.uploaded_keys.append(result["key"]) + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." + pytest.fail(f"Unexpected JigsawStackError in async {test_case['name']}: {e}") - def test_upload_with_partial_options_success_response(self) -> None: - # Test upload with partial options - file_content = b"This is a test file with partial options" - options = {"key": "partial-test-file.txt", "overwrite": False} + @pytest.mark.asyncio + async def test_file_get_async(self): + """Test asynchronous file retrieval""" + # First upload a file to retrieve + test_key = f"test-async-get-{uuid.uuid4().hex[:8]}.txt" try: - result = client.store.upload(file_content, options) - assert result["success"] == True + upload_result = await async_jigsaw.store.upload( + TEXT_FILE_CONTENT, {"key": test_key, "content_type": "text/plain"} + ) + + # Now retrieve it + file_content = await async_jigsaw.store.get(upload_result["key"]) + assert file_content is not None + print(f"Async retrieved file with key {upload_result['key']}") + + # Cleanup + self.uploaded_keys.append(upload_result["key"]) + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." 
+ pytest.fail(f"Unexpected JigsawStackError in async file get: {e}") diff --git a/tests/test_geo.py b/tests/test_geo.py deleted file mode 100644 index e97e3fb..0000000 --- a/tests/test_geo.py +++ /dev/null @@ -1,38 +0,0 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import AsyncJigsawStack -import pytest -import asyncio -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -@pytest.mark.skip(reason="Skipping TestWebAPI class for now") -def test_async_country_response(): - async def _test(): - client = AsyncJigsawStack() - try: - result = await client.geo.country({"country_code": "SGP"}) - logger.info(result) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) - - -@pytest.mark.skip(reason="Skipping TestWebAPI class for now") -def test_async_search_response(): - async def _test(): - client = AsyncJigsawStack() - try: - result = await client.geo.search({"search_value": "Nigeria"}) - logger.info(result) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) diff --git a/tests/test_image_generation.py b/tests/test_image_generation.py index 6cf275a..6b982ba 100644 --- a/tests/test_image_generation.py +++ b/tests/test_image_generation.py @@ -1,57 +1,221 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -import jigsawstack -import pytest -import asyncio import logging -import io +import os + +import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError +load_dotenv() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -jigsaw = jigsawstack.JigsawStack() -async_jigsaw = jigsawstack.AsyncJigsawStack() +jigsaw = 
jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +IMAGE_URL = "https://images.unsplash.com/photo-1494588024300-e9df7ff98d78?q=80&w=1284&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" +FILE_STORE_KEY = jigsaw.store.upload( + requests.get(IMAGE_URL).content, + {"filename": "test_image.jpg", "content_type": "image/jpeg", "overwrite": True}, +) + +TEST_CASES = [ + { + "name": "basic_generation_with_prompt", + "params": { + "prompt": "A beautiful mountain landscape at sunset", + }, + }, + { + "name": "with_aspect_ratio", + "params": { + "prompt": "A serene lake with mountains in the background", + "aspect_ratio": "16:9", + }, + }, + { + "name": "with_custom_dimensions", + "params": {"prompt": "A futuristic city skyline", "width": 1024, "height": 768}, + }, + { + "name": "with_output_format_png", + "params": {"prompt": "A colorful abstract painting", "output_format": "png"}, + }, + { + "name": "with_advanced_config", + "params": { + "prompt": "A realistic portrait of a person", + "advance_config": { + "negative_prompt": "blurry, low quality, distorted", + "guidance": 7, + "seed": 42, + }, + }, + }, + { + "name": "with_steps", + "params": { + "prompt": "A detailed botanical illustration", + "steps": 30, + "aspect_ratio": "3:4", + "return_type": "base64", + }, + }, + { + "name": "with_return_type_url", + "params": {"prompt": "A vintage car on a desert road", "return_type": "url"}, + }, + { + "name": "with_return_type_base64", + "params": {"prompt": "A fantasy castle on a hill", "return_type": "base64"}, + }, + { + "name": "with_all_options", + "params": { + "prompt": "An intricate steampunk clockwork mechanism", + "aspect_ratio": "4:3", + "steps": 25, + "output_format": "png", + "advance_config": { + "negative_prompt": "simple, plain, boring", + "guidance": 8, + "seed": 12345, + }, + "return_type": "base64", + }, + }, +] 
+ +# Test cases for image-to-image generation (using existing images as input) +IMAGE_TO_IMAGE_TEST_CASES = [ + { + "name": "with_url", + "params": { + "prompt": "Add snow effects to this image", + "url": IMAGE_URL, + "return_type": "base64", + }, + }, + { + "name": "with_file_store_key", + "params": { + "prompt": "Apply a cyberpunk style to this image", + "file_store_key": FILE_STORE_KEY, + }, + }, +] + + +class TestImageGenerationSync: + """Test synchronous image generation methods""" + + @pytest.mark.parametrize("test_case", TEST_CASES, ids=[tc["name"] for tc in TEST_CASES]) + def test_image_generation(self, test_case): + """Test synchronous image generation with various parameters""" + try: + result = jigsaw.image_generation(test_case["params"]) + + print(type(result)) + + if isinstance(result, dict): + print(result) + # Check response structure + assert result is not None + + if type(result) is dict: + # Check for image data based on return_type + if test_case["params"].get("return_type") == "url": + assert result.get("url") is not None + assert requests.get(result["url"]).status_code == 200 + assert isinstance(result["url"], str) + elif test_case["params"].get("return_type") == "base64": + assert result.get("url") is not None + elif test_case["params"].get("return_type") == "url": + assert result.get("url") is not None + assert requests.get(result["url"]).status_code == 200 + else: + assert isinstance(result, bytes) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") -def test_image_generation_response(): - async def _test(): - client = jigsawstack.AsyncJigsawStack() + @pytest.mark.parametrize( + "test_case", + IMAGE_TO_IMAGE_TEST_CASES[:1], + ids=[tc["name"] for tc in IMAGE_TO_IMAGE_TEST_CASES[:1]], + ) + def test_image_to_image_generation(self, test_case): + """Test image-to-image generation with URL input""" try: - result = await client.image_generation( - { - "prompt": "A beautiful mountain landscape 
at sunset", - "aspect_ratio": "16:9", - } - ) - # Just check if we got some data back + result = jigsaw.image_generation(test_case["params"]) + + print(f"Test {test_case['name']}: Generated image from input") assert result is not None - assert len(result) > 0 + + if type(result) is dict: + assert result.get("success") + assert result.get("url") is not None + elif type(result) is bytes: + assert isinstance(result, bytes) + else: + pytest.fail(f"Unexpected result type in {test_case['name']}: {type(result)}") except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") - asyncio.run(_test()) +class TestImageGenerationAsync: + """Test asynchronous image generation methods""" -def test_image_generation_with_advanced_config(): - async def _test(): - client = jigsawstack.AsyncJigsawStack() + @pytest.mark.parametrize("test_case", TEST_CASES, ids=[tc["name"] for tc in TEST_CASES]) + @pytest.mark.asyncio + async def test_image_generation_async(self, test_case): + """Test asynchronous image generation with various parameters""" try: - result = await client.image_generation( - { - "prompt": "A beautiful mountain landscape at sunset", - "output_format": "png", - "advance_config": { - "negative_prompt": "blurry, low quality", - "guidance": 7, - "seed": 42, - }, - } - ) - # Just check if we got some data back + result = await async_jigsaw.image_generation(test_case["params"]) + + print(f"Async test {test_case['name']}: Generated image") + + # Check response structure assert result is not None - assert len(result) > 0 + if type(result) is dict: + # Check for image data based on return_type + if test_case["params"].get("return_type") == "url": + assert result.get("url") is not None + assert requests.get(result["url"]).status_code == 200 + assert isinstance(result["url"], str) + assert result["url"].startswith("http") + elif test_case["params"].get("return_type") == "base64": + assert 
result.get("url") is not None + elif test_case["params"].get("return_type") == "url": + assert result.get("url") is not None + assert requests.get(result["url"]).status_code == 200 + else: + assert isinstance(result, bytes) + except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") + pytest.fail(f"Unexpected JigsawStackError in async {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", + IMAGE_TO_IMAGE_TEST_CASES[:1], + ids=[tc["name"] for tc in IMAGE_TO_IMAGE_TEST_CASES[:1]], + ) + @pytest.mark.asyncio + async def test_image_to_image_generation_async(self, test_case): + """Test asynchronous image-to-image generation with URL input""" + try: + result = await async_jigsaw.image_generation(test_case["params"]) - asyncio.run(_test()) + assert result is not None + if type(result) is dict: + assert result.get("success") + assert result.get("url") is not None + elif type(result) is bytes: + assert isinstance(result, bytes) + else: + pytest.fail(f"Unexpected result type in {test_case['name']}: {type(result)}") + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in async {test_case['name']}: {e}") diff --git a/tests/test_object_detection.py b/tests/test_object_detection.py index 521189c..1fbd5ca 100644 --- a/tests/test_object_detection.py +++ b/tests/test_object_detection.py @@ -1,43 +1,143 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -import jigsawstack -import pytest -import asyncio import logging +import os + +import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -jigsaw = jigsawstack.JigsawStack() -async_jigsaw = jigsawstack.AsyncJigsawStack() +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = 
jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) -def test_object_detection_response(): - try: - result = jigsaw.vision.object_detection( - { - "url": "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg" - } - ) - print(result) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") +IMAGE_URL = ( + "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg" +) +TEST_CASES = [ + { + "name": "with_url_only", + "params": {"url": IMAGE_URL}, + "blob": None, + "options": None, + }, + { + "name": "with_blob_only", + "params": None, + "blob": IMAGE_URL, + "options": None, + }, + { + "name": "annotated_image_true", + "blob": IMAGE_URL, + "options": {"annotated_image": True}, + }, + { + "name": "with_annotated_image_false", + "blob": IMAGE_URL, + "options": {"annotated_image": False}, + }, + { + "name": "with_blob_both_features", + "blob": IMAGE_URL, + "options": { + "features": ["object_detection", "gui"], + "annotated_image": True, + "return_type": "url", + }, + }, + { + "name": "with_blob_gui_features", + "blob": IMAGE_URL, + "options": {"features": ["gui"], "annotated_image": False}, + }, + { + "name": "with_blob_object_detection_features", + "blob": IMAGE_URL, + "options": { + "features": ["object_detection"], + "annotated_image": True, + "return_type": "base64", + }, + }, + { + "name": "with_prompts", + "blob": IMAGE_URL, + "options": { + "prompts": ["castle", "tree"], + "annotated_image": True, + }, + }, + { + "name": "with_all_options", + "blob": IMAGE_URL, + "options": { + "features": ["object_detection", "gui"], + "prompts": ["car", "road", "tree"], + "annotated_image": True, + "return_type": "base64", + "return_masks": False, + }, + }, +] -def test_object_detection_response_async(): - async def _test(): - client = jigsawstack.AsyncJigsawStack() + +class TestObjectDetectionSync: + """Test synchronous 
object detection methods""" + + sync_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_object_detection(self, test_case): + """Test synchronous object detection with various inputs""" try: - result = await client.vision.object_detection( - { - "url": "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg" - } - ) - print(result) - assert result["success"] == True + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = jigsaw.vision.object_detection(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = jigsaw.vision.object_detection(test_case["params"]) + + print(f"Test {test_case['name']}: {result}") + assert result["success"] except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + - asyncio.run(_test()) +class TestObjectDetectionAsync: + """Test asynchronous object detection methods""" + + async_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_object_detection_async(self, test_case): + """Test asynchronous object detection with various inputs""" + try: + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = await async_jigsaw.vision.object_detection( + blob_content, test_case.get("options", {}) + ) + else: + # Use params directly + result = await async_jigsaw.vision.object_detection(test_case["params"]) + + print(f"Test {test_case['name']}: {result}") + assert result["success"] + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_prediction.py b/tests/test_prediction.py 
new file mode 100644 index 0000000..a87ccab --- /dev/null +++ b/tests/test_prediction.py @@ -0,0 +1,185 @@ +import logging +import os +from datetime import datetime, timedelta + +import pytest +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + + +def generate_dates(start_date, num_days): + dates = [] + for i in range(num_days): + date = start_date + timedelta(days=i) + dates.append(date.strftime("%Y-%m-%d %H:%M:%S")) + return dates + + +start = datetime(2024, 1, 1) +dates = generate_dates(start, 30) +dates = [str(date) for date in dates] + +TEST_CASES = [ + { + "name": "linear_growth_pattern", + "params": { + "dataset": [{"date": dates[i], "value": 100 + (i * 10)} for i in range(10)], + "steps": 5, + }, + }, + { + "name": "exponential_growth_pattern", + "params": { + "dataset": [{"date": dates[i], "value": 100 * (1.1**i)} for i in range(10)], + "steps": 3, + }, + }, + { + "name": "seasonal_pattern", + "params": { + "dataset": [{"date": dates[i], "value": 100 + (50 * (i % 7))} for i in range(21)], + "steps": 7, + }, + }, + { + "name": "single_step_prediction", + "params": { + "dataset": [{"date": dates[i], "value": 200 + (i * 5)} for i in range(15)], + "steps": 1, + }, + }, + { + "name": "large_dataset_prediction", + "params": { + "dataset": [{"date": dates[i], "value": 1000 + (i * 20)} for i in range(30)], + "steps": 10, + }, + }, + { + "name": "declining_trend", + "params": { + "dataset": [{"date": dates[i], "value": 500 - (i * 10)} for i in range(10)], + "steps": 5, + }, + }, + { + "name": "volatile_data", + "params": { + "dataset": [ + {"date": dates[0], "value": 100}, + {"date": dates[1], "value": 150}, + {"date": dates[2], "value": 80}, + 
{"date": dates[3], "value": 200}, + {"date": dates[4], "value": 120}, + {"date": dates[5], "value": 180}, + {"date": dates[6], "value": 90}, + {"date": dates[7], "value": 160}, + ], + "steps": 4, + }, + }, + { + "name": "constant_values", + "params": { + "dataset": [{"date": dates[i], "value": 100} for i in range(10)], + "steps": 3, + }, + }, + { + "name": "string_values_prediction", + "params": { + "dataset": [ + {"date": dates[0], "value": "33.4"}, + {"date": dates[1], "value": "33.6"}, + {"date": dates[2], "value": "33.6"}, + {"date": dates[3], "value": "33.0"}, + {"date": dates[4], "value": "265.0"}, + {"date": dates[5], "value": "80"}, + {"date": dates[6], "value": "90.45"}, + ], + "steps": 3, + }, + }, + { + "name": "minimal_dataset", + "params": { + "dataset": [ + {"date": dates[0], "value": 50}, + {"date": dates[1], "value": 60}, + {"date": dates[2], "value": 70}, + {"date": dates[3], "value": 80}, + {"date": dates[4], "value": 90}, + ], + "steps": 2, + }, + }, +] + + +class TestPredictionSync: + """Test synchronous prediction methods""" + + sync_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_prediction(self, test_case): + """Test synchronous prediction with various inputs""" + try: + result = jigsaw.prediction(test_case["params"]) + + assert result["success"] + assert "prediction" in result + assert isinstance(result["prediction"], list) + + # Verify the number of predictions matches the requested steps + assert len(result["prediction"]) == test_case["params"]["steps"] + + # Verify each prediction has the required fields + for prediction in result["prediction"]: + assert "date" in prediction + assert "value" in prediction + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestPredictionAsync: + """Test asynchronous prediction methods""" + + async_test_cases = TEST_CASES + + 
@pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_prediction_async(self, test_case): + """Test asynchronous prediction with various inputs""" + try: + result = await async_jigsaw.prediction(test_case["params"]) + + assert result["success"] + assert "prediction" in result + assert isinstance(result["prediction"], list) + + # Verify the number of predictions matches the requested steps + assert len(result["prediction"]) == test_case["params"]["steps"] + + # Verify each prediction has the required fields + for prediction in result["prediction"]: + assert "date" in prediction + assert "value" in prediction + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_search.py b/tests/test_search.py deleted file mode 100644 index 1ee28f0..0000000 --- a/tests/test_search.py +++ /dev/null @@ -1,53 +0,0 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -import jigsawstack -import pytest -import asyncio -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -jigsaw = jigsawstack.JigsawStack() -async_jigsaw = jigsawstack.AsyncJigsawStack() - - -def test_search_suggestion_response(): - try: - result = jigsaw.web.search({"query": "Where is San Francisco"}) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - -def test_ai_search_response(): - try: - result = jigsaw.web.search({"query": "Where is San Francisco"}) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - -def test_search_suggestion_response_async(): - async def _test(): - client = jigsawstack.AsyncJigsawStack() - try: - result = await client.web.search({"query": "Where is San Francisco"}) - assert result["success"] 
== True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) - - -def test_ai_search_response_async(): - async def _test(): - client = jigsawstack.AsyncJigsawStack() - try: - result = await client.web.search({"query": "Where is San Francisco"}) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) diff --git a/tests/test_sentiment.py b/tests/test_sentiment.py index cd3c602..5bb5914 100644 --- a/tests/test_sentiment.py +++ b/tests/test_sentiment.py @@ -1,21 +1,136 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -import jigsawstack +import logging +import os import pytest +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() -# flake8: noqa +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -client = jigsawstack.JigsawStack() +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +TEST_CASES = [ + { + "name": "positive_sentiment_excited", + "params": { + "text": "I am so excited about this new product! It's absolutely amazing and I can't wait to use it every day." + }, + }, + { + "name": "negative_sentiment_disappointed", + "params": { + "text": "I'm really disappointed with this purchase. The quality is terrible and it broke after just one day." + }, + }, + { + "name": "neutral_sentiment_factual", + "params": {"text": "The meeting is scheduled for 3 PM tomorrow in conference room B."}, + }, + { + "name": "mixed_sentiment_paragraph", + "params": { + "text": "The product arrived on time which was great. However, the packaging was damaged. The item itself works fine, but the instructions were confusing." 
+ }, + }, + { + "name": "positive_sentiment_love", + "params": { + "text": "I absolutely love this! Best purchase I've made all year. Highly recommend to everyone!" + }, + }, + { + "name": "negative_sentiment_angry", + "params": { + "text": "This is unacceptable! I want a refund immediately. Worst customer service ever!" + }, + }, + { + "name": "single_sentence_positive", + "params": {"text": "This made my day!"}, + }, + { + "name": "single_sentence_negative", + "params": {"text": "I hate this."}, + }, + { + "name": "complex_multi_sentence", + "params": { + "text": "The first part of the movie was boring and I almost fell asleep. But then it got really exciting! The ending was spectacular and now it's one of my favorites." + }, + }, + { + "name": "question_sentiment", + "params": {"text": "Why is this product so amazing? I can't believe how well it works!"}, + }, +] -@pytest.mark.skip(reason="Skipping TestWebAPI class for now") -class TestSentimentAPI(unittest.TestCase): - def test_sentiment_response_success(self) -> None: - params = {"text": "I am so excited"} + +class TestSentimentSync: + """Test synchronous sentiment analysis methods""" + + sync_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_sentiment_analysis(self, test_case): + """Test synchronous sentiment analysis with various inputs""" try: - result = client.sentiment(params) - assert result["success"] == True + result = jigsaw.sentiment(test_case["params"]) + + assert result["success"] + assert "sentiment" in result + assert "emotion" in result["sentiment"] + assert "sentiment" in result["sentiment"] + assert "score" in result["sentiment"] + + # Check if sentences analysis is included + if "sentences" in result["sentiment"]: + assert isinstance(result["sentiment"]["sentences"], list) + for sentence in result["sentiment"]["sentences"]: + assert "text" in sentence + assert "sentiment" in sentence + assert 
"emotion" in sentence + assert "score" in sentence + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestSentimentAsync: + """Test asynchronous sentiment analysis methods""" + + async_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_sentiment_analysis_async(self, test_case): + """Test asynchronous sentiment analysis with various inputs""" + try: + result = await async_jigsaw.sentiment(test_case["params"]) + + assert result["success"] + assert "sentiment" in result + assert "emotion" in result["sentiment"] + assert "sentiment" in result["sentiment"] + assert "score" in result["sentiment"] + + # Check if sentences analysis is included + if "sentences" in result["sentiment"]: + assert isinstance(result["sentiment"]["sentences"], list) + for sentence in result["sentiment"]["sentences"]: + assert "text" in sentence + assert "sentiment" in sentence + assert "emotion" in sentence + assert "score" in sentence + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." 
+ pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100644 index 0000000..822ae18 --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,270 @@ +import logging +import os + +import pytest +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +# Sample schemas for different databases +MYSQL_SCHEMA = """ +CREATE TABLE users ( + id INT PRIMARY KEY AUTO_INCREMENT, + username VARCHAR(255) NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE orders ( + id INT PRIMARY KEY AUTO_INCREMENT, + user_id INT, + product_name VARCHAR(255), + quantity INT, + price DECIMAL(10, 2), + order_date DATE, + FOREIGN KEY (user_id) REFERENCES users(id) +); +""" + +POSTGRESQL_SCHEMA = """ +CREATE TABLE employees ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + department VARCHAR(50), + salary NUMERIC(10, 2), + hire_date DATE, + is_active BOOLEAN DEFAULT true +); + +CREATE TABLE departments ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) UNIQUE NOT NULL, + budget NUMERIC(12, 2), + manager_id INTEGER REFERENCES employees(id) +); +""" + +SQLITE_SCHEMA = """ +CREATE TABLE products ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + category TEXT, + price REAL, + stock_quantity INTEGER DEFAULT 0 +); + +CREATE TABLE sales ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + product_id INTEGER, + quantity INTEGER, + sale_date TEXT, + total_amount REAL, + FOREIGN KEY (product_id) REFERENCES products(id) +); +""" + +TEST_CASES = [ + { + "name": "mysql_simple_select", + "params": { + "prompt": "Get all users from the users table", + 
"sql_schema": MYSQL_SCHEMA, + "database": "mysql", + }, + }, + { + "name": "mysql_join_query", + "params": { + "prompt": "Get all orders with user information for orders placed in the last 30 days", + "sql_schema": MYSQL_SCHEMA, + "database": "mysql", + }, + }, + { + "name": "mysql_aggregate_query", + "params": { + "prompt": "Calculate the total revenue per user", + "sql_schema": MYSQL_SCHEMA, + "database": "mysql", + }, + }, + { + "name": "postgresql_simple_select", + "params": { + "prompt": "Find all active employees", + "sql_schema": POSTGRESQL_SCHEMA, + "database": "postgresql", + }, + }, + { + "name": "postgresql_complex_join", + "params": { + "prompt": "Get all departments with their manager names and department budgets greater than 100000", + "sql_schema": POSTGRESQL_SCHEMA, + "database": "postgresql", + }, + }, + { + "name": "postgresql_window_function", + "params": { + "prompt": "Rank employees by salary within each department", + "sql_schema": POSTGRESQL_SCHEMA, + "database": "postgresql", + }, + }, + { + "name": "sqlite_simple_query", + "params": { + "prompt": "List all products in the electronics category", + "sql_schema": SQLITE_SCHEMA, + "database": "sqlite", + }, + }, + { + "name": "sqlite_aggregate_with_group", + "params": { + "prompt": "Calculate total sales amount for each product", + "sql_schema": SQLITE_SCHEMA, + "database": "sqlite", + }, + }, + { + "name": "default_database_type", + "params": { + "prompt": "Select all records from users table where email contains 'example.com'", + "sql_schema": MYSQL_SCHEMA, + # No database specified, should use default + }, + }, + { + "name": "complex_multi_table_query", + "params": { + "prompt": "Find users who have placed more than 5 orders with total value exceeding 1000", + "sql_schema": MYSQL_SCHEMA, + "database": "mysql", + }, + }, + { + "name": "insert_query", + "params": { + "prompt": "Insert a new user with username 'john_doe' and email 'john@example.com'", + "sql_schema": MYSQL_SCHEMA, + "database": 
"mysql", + }, + }, + { + "name": "update_query", + "params": { + "prompt": "Update the salary of all employees in the IT department by 10%", + "sql_schema": POSTGRESQL_SCHEMA, + "database": "postgresql", + }, + }, + { + "name": "delete_query", + "params": { + "prompt": "Delete all products with zero stock quantity", + "sql_schema": SQLITE_SCHEMA, + "database": "sqlite", + }, + }, + { + "name": "subquery_example", + "params": { + "prompt": "Find all users who have never placed an order", + "sql_schema": MYSQL_SCHEMA, + "database": "mysql", + }, + }, + { + "name": "date_filtering", + "params": { + "prompt": "Get all employees hired in the last year", + "sql_schema": POSTGRESQL_SCHEMA, + "database": "postgresql", + }, + }, +] + + +class TestSQLSync: + """Test synchronous SQL text-to-sql methods""" + + sync_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_text_to_sql(self, test_case): + """Test synchronous text-to-sql with various inputs""" + try: + result = jigsaw.text_to_sql(test_case["params"]) + + assert result["success"] + assert "sql" in result + assert isinstance(result["sql"], str) + assert len(result["sql"]) > 0 + + # Basic SQL validation - check if it contains SQL keywords + sql_lower = result["sql"].lower() + sql_keywords = [ + "select", + "insert", + "update", + "delete", + "create", + "alter", + "drop", + ] + assert any(keyword in sql_lower for keyword in sql_keywords), ( + "Generated SQL should contain valid SQL keywords" + ) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestSQLAsync: + """Test asynchronous SQL text-to-sql methods""" + + async_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_text_to_sql_async(self, test_case): + """Test asynchronous text-to-sql with 
various inputs""" + try: + result = await async_jigsaw.text_to_sql(test_case["params"]) + + assert result["success"] + assert "sql" in result + assert isinstance(result["sql"], str) + assert len(result["sql"]) > 0 + + sql_lower = result["sql"].lower() + sql_keywords = [ + "select", + "insert", + "update", + "delete", + "create", + "alter", + "drop", + ] + assert any(keyword in sql_lower for keyword in sql_keywords), ( + "Generated SQL should contain valid SQL keywords" + ) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_store.py b/tests/test_store.py deleted file mode 100644 index 4d59ac7..0000000 --- a/tests/test_store.py +++ /dev/null @@ -1,51 +0,0 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import AsyncJigsawStack -import pytest -import asyncio -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -@pytest.mark.skip(reason="Skipping TestWebAPI class for now") -class TestAsyncFileOperations: - """ - Test class for async file operations. - Add your file operation tests here. 
- """ - - def test_async_file_upload(self): - # Template for future file upload tests - pass - - def test_async_file_retrieval(self): - # Template for future file retrieval tests - pass - - def test_async_file_deletion(self): - # Template for future file deletion tests - pass - - -# Example file upload test -# Uncomment and modify as needed -""" -def test_async_file_upload_example(): - async def _test(): - client = AsyncJigsawStack() - try: - file_content = b"test file content" - result = await client.store.upload( - file_content, - {"filename": "test.txt", "overwrite": True} - ) - logger.info(result) - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) -""" diff --git a/tests/test_summary.py b/tests/test_summary.py new file mode 100644 index 0000000..ab79ea9 --- /dev/null +++ b/tests/test_summary.py @@ -0,0 +1,183 @@ +import logging +import os + +import pytest +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +LONG_TEXT = """ +Artificial Intelligence (AI) has become one of the most transformative technologies of the 21st century. +From healthcare to finance, transportation to entertainment, AI is reshaping industries and changing the way we live and work. +Machine learning algorithms can now diagnose diseases with remarkable accuracy, predict market trends, and even create art. +Natural language processing has enabled computers to understand and generate human language, leading to the development of sophisticated chatbots and virtual assistants. +Computer vision systems can identify objects, faces, and activities in images and videos with superhuman precision. 
+However, the rapid advancement of AI also raises important ethical questions about privacy, job displacement, and the potential for bias in algorithmic decision-making. +As we continue to develop more powerful AI systems, it's crucial that we consider their societal impact and work to ensure that the benefits of AI are distributed equitably. +The future of AI holds immense promise, but it will require careful planning, regulation, and collaboration between technologists, policymakers, and society at large to realize its full potential while mitigating its risks. +""" + +ARTICLE_URL = "https://en.wikipedia.org/wiki/Artificial_intelligence" +PDF_URL = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" + +TEST_CASES = [ + { + "name": "text_summary_default", + "params": { + "text": LONG_TEXT, + }, + }, + { + "name": "text_summary_with_text_type", + "params": { + "text": LONG_TEXT, + "type": "text", + }, + }, + { + "name": "text_summary_with_points_type", + "params": { + "text": LONG_TEXT, + "type": "points", + }, + }, + { + "name": "text_summary_with_max_points", + "params": { + "text": LONG_TEXT, + "type": "points", + "max_points": 5, + }, + }, + { + "name": "text_summary_with_max_characters", + "params": { + "text": LONG_TEXT, + "type": "text", + "max_characters": 200, + }, + }, + { + "name": "short_text_summary", + "params": { + "text": "This is a short text that doesn't need much summarization.", + }, + }, + { + "name": "url_summary_default", + "params": { + "url": ARTICLE_URL, + }, + }, + { + "name": "url_summary_with_text_type", + "params": { + "url": ARTICLE_URL, + "type": "text", + }, + }, + { + "name": "url_summary_with_points_type", + "params": { + "url": ARTICLE_URL, + "type": "points", + "max_points": 7, + }, + }, + { + "name": "pdf_url_summary", + "params": { + "url": PDF_URL, + "type": "text", + }, + }, + { + "name": "complex_text_with_points_and_limit", + "params": { + "text": LONG_TEXT * 3, # Triple the text for more content + 
"type": "points", + "max_points": 10, + }, + }, + { + "name": "technical_text_summary", + "params": { + "text": """ + Machine learning is a subset of artificial intelligence that focuses on the development of algorithms and statistical models that enable computer systems to improve their performance on a specific task through experience. + Deep learning, a subfield of machine learning, uses artificial neural networks with multiple layers to progressively extract higher-level features from raw input. + Supervised learning involves training models on labeled data, while unsupervised learning discovers patterns in unlabeled data. + Reinforcement learning enables agents to learn optimal behaviors through trial and error interactions with an environment. + """, + "type": "points", + "max_points": 4, + }, + }, +] + + +class TestSummarySync: + """Test synchronous summary methods""" + + sync_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_summary(self, test_case): + """Test synchronous summary with various inputs""" + try: + result = jigsaw.summary(test_case["params"]) + + assert result["success"] + assert "summary" in result + + if test_case["params"].get("type") == "points": + assert isinstance(result["summary"], list) + if "max_points" in test_case["params"]: + assert len(result["summary"]) <= test_case["params"]["max_points"] + else: + assert isinstance(result["summary"], str) + if "max_characters" in test_case["params"]: + assert len(result["summary"]) <= test_case["params"]["max_characters"] + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestSummaryAsync: + """Test asynchronous summary methods""" + + async_test_cases = TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_summary_async(self, 
test_case): + """Test asynchronous summary with various inputs""" + try: + result = await async_jigsaw.summary(test_case["params"]) + + assert result["success"] + assert "summary" in result + + if test_case["params"].get("type") == "points": + assert isinstance(result["summary"], list) + if "max_points" in test_case["params"]: + assert len(result["summary"]) <= test_case["params"]["max_points"] + else: + assert isinstance(result["summary"], str) + if "max_characters" in test_case["params"]: + assert len(result["summary"]) <= test_case["params"]["max_characters"] + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_translate.py b/tests/test_translate.py new file mode 100644 index 0000000..5b560be --- /dev/null +++ b/tests/test_translate.py @@ -0,0 +1,238 @@ +import logging +import os + +import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +# Sample image URL for translation tests +IMAGE_URL = "https://images.unsplash.com/photo-1580679137870-86ef9f9a03d6?q=80&w=2574&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + +# Text translation test cases +TEXT_TEST_CASES = [ + { + "name": "translate_single_text_to_spanish", + "params": { + "text": "Hello, how are you?", + "target_language": "es", + }, + }, + { + "name": "translate_single_text_with_current_language", + "params": { + "text": "Bonjour, comment allez-vous?", + "current_language": "fr", + "target_language": "en", + }, + }, + { + "name": "translate_multiple_texts", + "params": { + "text": ["Hello world", "Good morning", "Thank you"], + 
"target_language": "fr", + }, + }, + { + "name": "translate_to_german", + "params": { + "text": "The weather is beautiful today", + "target_language": "de", + }, + }, + { + "name": "translate_to_japanese", + "params": { + "text": "Welcome to our website", + "target_language": "ja", + }, + }, + { + "name": "translate_multiple_with_source_language", + "params": { + "text": ["Ciao", "Grazie", "Arrivederci"], + "current_language": "it", + "target_language": "en", + }, + }, +] + +# Image translation test cases +IMAGE_TEST_CASES = [ + { + "name": "translate_image_with_url", + "params": { + "url": IMAGE_URL, + "target_language": "es", + }, + "blob": None, + "options": None, + }, + { + "name": "translate_image_with_blob", + "params": None, + "blob": IMAGE_URL, + "options": { + "target_language": "fr", + }, + }, + { + "name": "translate_image_with_url_return_base64", + "params": { + "url": IMAGE_URL, + "target_language": "de", + "return_type": "base64", + }, + "blob": None, + "options": None, + }, + { + "name": "translate_image_with_blob_return_url", + "params": None, + "blob": IMAGE_URL, + "options": { + "target_language": "ja", + "return_type": "url", + }, + }, + { + "name": "translate_image_with_blob_return_binary", + "params": None, + "blob": IMAGE_URL, + "options": { + "target_language": "zh", + "return_type": "binary", + }, + }, + { + "name": "translate_image_to_italian", + "params": { + "url": IMAGE_URL, + "target_language": "it", + }, + "blob": None, + "options": None, + }, +] + + +class TestTranslateTextSync: + """Test synchronous text translation methods""" + + sync_test_cases = TEXT_TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_translate_text(self, test_case): + """Test synchronous text translation with various inputs""" + try: + result = jigsaw.translate.text(test_case["params"]) + assert result["success"] + assert "translated_text" in result + + # Check if the response 
structure matches the input + if isinstance(test_case["params"]["text"], list): + assert isinstance(result["translated_text"], list) + assert len(result["translated_text"]) == len(test_case["params"]["text"]) + else: + assert isinstance(result["translated_text"], str) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestTranslateTextAsync: + """Test asynchronous text translation methods""" + + async_test_cases = TEXT_TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_translate_text_async(self, test_case): + """Test asynchronous text translation with various inputs""" + try: + result = await async_jigsaw.translate.text(test_case["params"]) + assert result["success"] + assert "translated_text" in result + + # Check if the response structure matches the input + if isinstance(test_case["params"]["text"], list): + assert isinstance(result["translated_text"], list) + assert len(result["translated_text"]) == len(test_case["params"]["text"]) + else: + assert isinstance(result["translated_text"], str) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestTranslateImageSync: + """Test synchronous image translation methods""" + + sync_test_cases = IMAGE_TEST_CASES + + @pytest.mark.parametrize( + "test_case", sync_test_cases, ids=[tc["name"] for tc in sync_test_cases] + ) + def test_translate_image(self, test_case): + """Test synchronous image translation with various inputs""" + try: + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = jigsaw.translate.image(blob_content, test_case.get("options", {})) + else: + # Use params directly + result = jigsaw.translate.image(test_case["params"]) + assert result is not None + if isinstance(result, dict): + assert "url" in 
result + else: + assert isinstance(result, bytes) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestTranslateImageAsync: + """Test asynchronous image translation methods""" + + async_test_cases = IMAGE_TEST_CASES + + @pytest.mark.parametrize( + "test_case", async_test_cases, ids=[tc["name"] for tc in async_test_cases] + ) + @pytest.mark.asyncio + async def test_translate_image_async(self, test_case): + """Test asynchronous image translation with various inputs""" + try: + if test_case.get("blob"): + # Download blob content + blob_content = requests.get(test_case["blob"]).content + result = await async_jigsaw.translate.image( + blob_content, test_case.get("options", {}) + ) + else: + # Use params directly + result = await async_jigsaw.translate.image(test_case["params"]) + assert result is not None + if isinstance(result, dict): + assert "url" in result + else: + assert isinstance(result, bytes) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_validate.py b/tests/test_validate.py index 51b8d3d..d0d2c43 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -1,42 +1,439 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import AsyncJigsawStack -import pytest -import asyncio import logging +import os + +import pytest +import requests +from dotenv import load_dotenv + +import jigsawstack +from jigsawstack.exceptions import JigsawStackError + +load_dotenv() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +# Sample URLs for NSFW testing +SAFE_IMAGE_URL = "https://images.unsplash.com/photo-1506905925346-21bda4d32df4?q=80&w=2070" 
+POTENTIALLY_NSFW_URL = "https://images.unsplash.com/photo-1512310604669-443f26c35f52?q=80&w=868&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + +SPAM_CHECK_TEST_CASES = [ + { + "name": "single_text_not_spam", + "params": { + "text": "I had a great experience with your product. The customer service was excellent!" + }, + }, + { + "name": "single_text_potential_spam", + "params": { + "text": "CLICK HERE NOW!!! FREE MONEY!!! Win $1000000 instantly! No credit card required! Act NOW!" + }, + }, + { + "name": "multiple_texts_mixed", + "params": { + "text": [ + "Thank you for your email. I'll get back to you soon.", + "BUY NOW! LIMITED TIME OFFER! 90% OFF EVERYTHING!", + "The meeting is scheduled for 3 PM tomorrow.", + ] + }, + }, + { + "name": "professional_email", + "params": { + "text": "Dear John, I hope this email finds you well. I wanted to follow up on our discussion from yesterday." + }, + }, + { + "name": "marketing_spam", + "params": { + "text": "Congratulations! You've been selected as our lucky winner! 
Claim your prize now at this link: bit.ly/win" + }, + }, +] + +# Spell Check Test Cases +SPELL_CHECK_TEST_CASES = [ + { + "name": "text_with_no_errors", + "params": {"text": "The quick brown fox jumps over the lazy dog."}, + }, + { + "name": "text_with_spelling_errors", + "params": {"text": "Thiss sentense has severel speling erors in it."}, + }, + { + "name": "text_with_language_code", + "params": {"text": "I recieved the pacakge yesterday.", "language_code": "en"}, + }, + { + "name": "mixed_correct_and_incorrect", + "params": {"text": "The weather is beatiful today, but tommorow might be diferent."}, + }, + { + "name": "technical_text", + "params": {"text": "The algorythm processes the datbase queries eficiently."}, + }, +] + +# Profanity Test Cases +PROFANITY_TEST_CASES = [ + { + "name": "clean_text", + "params": {"text": "This is a perfectly clean and professional message."}, + }, + { + "name": "text_with_profanity", + "params": { + "text": "This fucking thing is not working properly.", + "censor_replacement": "****", + }, + }, + { + "name": "text_with_custom_censor", + "params": { + "text": "What the fuck is going on here?", + "censor_replacement": "[CENSORED]", + }, + }, + { + "name": "mixed_clean_and_profane", + "params": {"text": "The weather is nice but this damn traffic is terrible."}, + }, + { + "name": "no_censor_replacement", + "params": {"text": "This text might contain some inappropriate words."}, + }, +] + +# NSFW Test Cases +NSFW_TEST_CASES = [ + { + "name": "safe_image_url", + "params": {"url": SAFE_IMAGE_URL}, + }, + { + "name": "landscape_image_url", + "params": {"url": POTENTIALLY_NSFW_URL}, + }, +] + +# NSFW Blob Test Cases +NSFW_BLOB_TEST_CASES = [ + { + "name": "safe_image_blob", + "blob_url": SAFE_IMAGE_URL, + "options": {}, + }, +] + + +class TestSpamCheckSync: + """Test synchronous spam check methods""" + + @pytest.mark.parametrize( + "test_case", + SPAM_CHECK_TEST_CASES, + ids=[tc["name"] for tc in SPAM_CHECK_TEST_CASES], + ) + def 
test_spam_check(self, test_case): + """Test synchronous spam check with various inputs""" + try: + result = jigsaw.validate.spamcheck(test_case["params"]) + + assert result["success"] + assert "check" in result + + # Check structure based on input type + if isinstance(test_case["params"]["text"], list): + assert isinstance(result["check"], list) + for check in result["check"]: + assert "is_spam" in check + assert "score" in check + assert isinstance(check["is_spam"], bool) + assert 0 <= check["score"] <= 1 + else: + assert "is_spam" in result["check"] + assert "score" in result["check"] + assert isinstance(result["check"]["is_spam"], bool) + assert 0 <= result["check"]["score"] <= 1 + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestSpellCheckSync: + """Test synchronous spell check methods""" + + @pytest.mark.parametrize( + "test_case", + SPELL_CHECK_TEST_CASES, + ids=[tc["name"] for tc in SPELL_CHECK_TEST_CASES], + ) + def test_spell_check(self, test_case): + """Test synchronous spell check with various inputs""" + try: + result = jigsaw.validate.spellcheck(test_case["params"]) + + assert result["success"] + assert "misspellings_found" in result + assert "misspellings" in result + assert "auto_correct_text" in result + assert isinstance(result["misspellings_found"], bool) + assert isinstance(result["misspellings"], list) + assert isinstance(result["auto_correct_text"], str) + + # Check misspellings structure + for misspelling in result["misspellings"]: + assert "word" in misspelling + assert "startIndex" in misspelling + assert "endIndex" in misspelling + assert "expected" in misspelling + assert "auto_corrected" in misspelling + assert isinstance(misspelling["expected"], list) + assert isinstance(misspelling["auto_corrected"], bool) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + -@pytest.mark.skip(reason="Skipping TestWebAPI class 
for now") -def test_async_spam_check_response(): - async def _test(): - client = AsyncJigsawStack() +class TestProfanitySync: + """Test synchronous profanity check methods""" + + @pytest.mark.parametrize( + "test_case", + PROFANITY_TEST_CASES, + ids=[tc["name"] for tc in PROFANITY_TEST_CASES], + ) + def test_profanity_check(self, test_case): + """Test synchronous profanity check with various inputs""" try: - result = await client.validate.spamcheck({"text": "I am happy!"}) - logger.info(result) - assert result["success"] == True + result = jigsaw.validate.profanity(test_case["params"]) + + assert result["success"] + assert "clean_text" in result + assert "profanities" in result + assert "profanities_found" in result + assert isinstance(result["profanities_found"], bool) + assert isinstance(result["profanities"], list) + assert isinstance(result["clean_text"], str) + + # Check profanities structure + for profanity in result["profanities"]: + assert "profanity" in profanity + assert "startIndex" in profanity + assert "endIndex" in profanity + except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") - asyncio.run(_test()) +class TestNSFWSync: + """Test synchronous NSFW check methods""" -@pytest.mark.skip(reason="Skipping TestWebAPI class for now") -def test_async_spell_check_response(): - async def _test(): - client = AsyncJigsawStack() + @pytest.mark.parametrize( + "test_case", NSFW_TEST_CASES, ids=[tc["name"] for tc in NSFW_TEST_CASES] + ) + def test_nsfw_check(self, test_case): + """Test synchronous NSFW check with various inputs""" try: - result = await client.validate.spellcheck( - { - "text": "All the world's a stage, and all the men and women merely players. 
They have their exits and their entrances; And one man in his time plays many parts" - } - ) - logger.info(result) - assert result["success"] == True + result = jigsaw.validate.nsfw(test_case["params"]) + + assert result["success"] + assert "nsfw" in result + assert "nudity" in result + assert "gore" in result + assert "nsfw_score" in result + assert "nudity_score" in result + assert "gore_score" in result + + assert isinstance(result["nsfw"], bool) + assert isinstance(result["nudity"], bool) + assert isinstance(result["gore"], bool) + assert 0 <= result["nsfw_score"] <= 1 + assert 0 <= result["nudity_score"] <= 1 + assert 0 <= result["gore_score"] <= 1 + except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", + NSFW_BLOB_TEST_CASES, + ids=[tc["name"] for tc in NSFW_BLOB_TEST_CASES], + ) + def test_nsfw_check_blob(self, test_case): + """Test synchronous NSFW check with blob inputs""" + try: + # Download blob content + blob_content = requests.get(test_case["blob_url"]).content + result = jigsaw.validate.nsfw(blob_content, test_case["options"]) + + assert result["success"] + assert "nsfw" in result + assert "nudity" in result + assert "gore" in result + assert "nsfw_score" in result + assert "nudity_score" in result + assert "gore_score" in result - asyncio.run(_test()) + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +# Async Test Classes + + +class TestSpamCheckAsync: + """Test asynchronous spam check methods""" + + @pytest.mark.parametrize( + "test_case", + SPAM_CHECK_TEST_CASES, + ids=[tc["name"] for tc in SPAM_CHECK_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_spam_check_async(self, test_case): + """Test asynchronous spam check with various inputs""" + try: + result = await async_jigsaw.validate.spamcheck(test_case["params"]) + + assert 
result["success"] + assert "check" in result + + # Check structure based on input type + if isinstance(test_case["params"]["text"], list): + assert isinstance(result["check"], list) + for check in result["check"]: + assert "is_spam" in check + assert "score" in check + assert isinstance(check["is_spam"], bool) + assert 0 <= check["score"] <= 1 + else: + assert "is_spam" in result["check"] + assert "score" in result["check"] + assert isinstance(result["check"]["is_spam"], bool) + assert 0 <= result["check"]["score"] <= 1 + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestSpellCheckAsync: + """Test asynchronous spell check methods""" + + @pytest.mark.parametrize( + "test_case", + SPELL_CHECK_TEST_CASES, + ids=[tc["name"] for tc in SPELL_CHECK_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_spell_check_async(self, test_case): + """Test asynchronous spell check with various inputs""" + try: + result = await async_jigsaw.validate.spellcheck(test_case["params"]) + + assert result["success"] + assert "misspellings_found" in result + assert "misspellings" in result + assert "auto_correct_text" in result + assert isinstance(result["misspellings_found"], bool) + assert isinstance(result["misspellings"], list) + assert isinstance(result["auto_correct_text"], str) + + # Check misspellings structure + for misspelling in result["misspellings"]: + assert "word" in misspelling + assert "startIndex" in misspelling + assert "endIndex" in misspelling + assert "expected" in misspelling + assert "auto_corrected" in misspelling + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestProfanityAsync: + """Test asynchronous profanity check methods""" + + @pytest.mark.parametrize( + "test_case", + PROFANITY_TEST_CASES, + ids=[tc["name"] for tc in PROFANITY_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_profanity_check_async(self, 
test_case): + """Test asynchronous profanity check with various inputs""" + try: + result = await async_jigsaw.validate.profanity(test_case["params"]) + + assert result["success"] + assert "clean_text" in result + assert "profanities" in result + assert "profanities_found" in result + assert isinstance(result["profanities_found"], bool) + assert isinstance(result["profanities"], list) + assert isinstance(result["clean_text"], str) + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestNSFWAsync: + """Test asynchronous NSFW check methods""" + + @pytest.mark.parametrize( + "test_case", NSFW_TEST_CASES, ids=[tc["name"] for tc in NSFW_TEST_CASES] + ) + @pytest.mark.asyncio + async def test_nsfw_check_async(self, test_case): + """Test asynchronous NSFW check with various inputs""" + try: + result = await async_jigsaw.validate.nsfw(test_case["params"]) + + assert result["success"] + assert "nsfw" in result + assert "nudity" in result + assert "gore" in result + assert "nsfw_score" in result + assert "nudity_score" in result + assert "gore_score" in result + + assert isinstance(result["nsfw"], bool) + assert isinstance(result["nudity"], bool) + assert isinstance(result["gore"], bool) + assert 0 <= result["nsfw_score"] <= 1 + assert 0 <= result["nudity_score"] <= 1 + assert 0 <= result["gore_score"] <= 1 + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + @pytest.mark.parametrize( + "test_case", + NSFW_BLOB_TEST_CASES, + ids=[tc["name"] for tc in NSFW_BLOB_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_nsfw_check_blob_async(self, test_case): + """Test asynchronous NSFW check with blob inputs""" + try: + # Download blob content + blob_content = requests.get(test_case["blob_url"]).content + result = await async_jigsaw.validate.nsfw(blob_content, test_case["options"]) + + assert result["success"] + assert "nsfw" in result + assert "nudity" 
in result + assert "gore" in result + assert "nsfw_score" in result + assert "nudity_score" in result + assert "gore_score" in result + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") diff --git a/tests/test_vision.py b/tests/test_vision.py deleted file mode 100644 index 7d8fcf0..0000000 --- a/tests/test_vision.py +++ /dev/null @@ -1,28 +0,0 @@ -from unittest.mock import MagicMock -import unittest -from jigsawstack.exceptions import JigsawStackError -from jigsawstack import AsyncJigsawStack -import pytest -import asyncio -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_async_vocr_response(): - async def _test(): - client = AsyncJigsawStack() - try: - result = await client.vision.vocr( - { - "url": "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg?t=2024-03-22T09%3A22%3A48.442Z", - "prompt": ["Hello"], - } - ) - - assert result["success"] == True - except JigsawStackError as e: - pytest.fail(f"Unexpected JigsawStackError: {e}") - - asyncio.run(_test()) diff --git a/tests/test_web.py b/tests/test_web.py index 5191fca..c22ccd7 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -1,45 +1,317 @@ -from unittest.mock import MagicMock -import unittest from jigsawstack.exceptions import JigsawStackError -from jigsawstack import JigsawStack - +import jigsawstack import pytest +import logging +from dotenv import load_dotenv +import os + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +jigsaw = jigsawstack.JigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) +async_jigsaw = jigsawstack.AsyncJigsawStack(api_key=os.getenv("JIGSAWSTACK_API_KEY")) + +URL = "https://jigsawstack.com" + +# HTML to Any Test Cases +HTML_TO_ANY_TEST_CASES = [ + { + "name": "html_to_pdf_url", + "params": { + "url": URL, + "type": "pdf", + "return_type": "url", + }, + }, + { + "name": 
"html_to_png_base64", + "params": { + "url": URL, + "type": "png", + "return_type": "base64", + }, + }, + { + "name": "html_to_jpeg_binary", + "params": { + "url": URL, + "type": "jpeg", + "return_type": "binary", + }, + }, + { + "name": "html_string_to_pdf", + "params": { + "html": "

Test Document

This is a test.

", + "type": "pdf", + "return_type": "url", + }, + }, + { + "name": "html_to_pdf_with_options", + "params": { + "url": URL, + "type": "pdf", + "return_type": "url", + "pdf_display_header_footer": True, + "pdf_print_background": True, + }, + }, + { + "name": "html_to_png_full_page", + "params": { + "url": URL, + "type": "png", + "full_page": True, + "return_type": "url", + }, + }, + { + "name": "html_to_webp_custom_size", + "params": { + "url": URL, + "type": "webp", + "width": 1920, + "height": 1080, + "return_type": "base64", + }, + }, + { + "name": "html_to_png_mobile", + "params": { + "url": URL, + "type": "png", + "is_mobile": True, + "return_type": "url", + }, + }, + { + "name": "html_to_png_dark_mode", + "params": { + "url": URL, + "type": "png", + "dark_mode": True, + "return_type": "url", + }, + }, +] + +# Search Test Cases +SEARCH_TEST_CASES = [ + { + "name": "basic_search", + "params": { + "query": "artificial intelligence news", + }, + }, + { + "name": "search_specific_site", + "params": { + "query": "documentation site:github.com", + }, + }, + { + "name": "search_ai_mode", + "params": { + "query": "explain quantum computing", + "ai_overview": True, + }, + }, +] + +# Search Suggestions Test Cases +SEARCH_SUGGESTIONS_TEST_CASES = [ + { + "name": "basic_suggestions", + "params": { + "query": "machine learn", + }, + }, + { + "name": "programming_suggestions", + "params": { + "query": "python tutor", + }, + }, + { + "name": "partial_query_suggestions", + "params": { + "query": "artifi", + }, + }, +] + +class TestHTMLToAnySync: + """Test synchronous HTML to Any methods""" + + @pytest.mark.parametrize( + "test_case", + HTML_TO_ANY_TEST_CASES, + ids=[tc["name"] for tc in HTML_TO_ANY_TEST_CASES], + ) + def test_html_to_any(self, test_case): + """Test synchronous HTML to Any with various inputs""" + try: + result = jigsaw.web.html_to_any(test_case["params"]) + + return_type = test_case["params"].get("return_type", "url") + + if return_type == "binary": + assert 
isinstance(result, bytes) + assert len(result) > 0 + else: + assert result["success"] + assert "url" in result + assert isinstance(result["url"], str) + + if return_type == "base64": + # Check if it's a valid base64 string + assert result["url"].startswith("data:") + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestSearchSync: + """Test synchronous search methods""" + + @pytest.mark.parametrize( + "test_case", SEARCH_TEST_CASES, ids=[tc["name"] for tc in SEARCH_TEST_CASES] + ) + def test_search(self, test_case): + """Test synchronous search with various inputs""" + try: + result = jigsaw.web.search(test_case["params"]) + + assert result["success"] + assert "results" in result + assert isinstance(result["results"], list) + + if test_case["params"].get("max_results"): + assert len(result["results"]) <= test_case["params"]["max_results"] + + # Check result structure + for item in result["results"]: + assert "title" in item + assert "url" in item + assert "description" in item -# flake8: noqa + # Check AI overview response (request flag is "ai_overview", as in the async test) + if test_case["params"].get("ai_overview"): + assert "ai_overview" in result + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") -client = JigsawStack() +class TestSearchSuggestionsSync: + """Test synchronous search suggestions methods""" -@pytest.mark.skip(reason="Skipping TestWebAPI class for now") -class TestWebAPI(unittest.TestCase): - def test_ai_scrape_success_response(self) -> None: - params = { - "url": "https://supabase.com/pricing", - "element_prompts": ["Plan title", "Plan price"], - } + @pytest.mark.parametrize( + "test_case", + SEARCH_SUGGESTIONS_TEST_CASES, + ids=[tc["name"] for tc in SEARCH_SUGGESTIONS_TEST_CASES], + ) + def test_search_suggestions(self, test_case): + """Test synchronous search suggestions with various inputs""" + try: - result = client.file.upload(params) - assert result["success"] == True + result = 
jigsaw.web.search_suggestions(test_case["params"]) + + assert result["success"] + assert "suggestions" in result + assert isinstance(result["suggestions"], list) + assert len(result["suggestions"]) > 0 + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + +# Async Test Classes + +class TestHTMLToAnyAsync: + """Test asynchronous HTML to Any methods""" - def test_scrape_success_response(self) -> None: - params = { - "url": "https://supabase.com/pricing", - } + @pytest.mark.parametrize( + "test_case", + HTML_TO_ANY_TEST_CASES, + ids=[tc["name"] for tc in HTML_TO_ANY_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_html_to_any_async(self, test_case): + """Test asynchronous HTML to Any with various inputs""" try: - result = client.web.scrape(params) - assert result["success"] == True + result = await async_jigsaw.web.html_to_any(test_case["params"]) + + return_type = test_case["params"].get("return_type", "url") + + if return_type == "binary": + assert isinstance(result, bytes) + assert len(result) > 0 + else: + assert result["success"] + assert "url" in result + assert isinstance(result["url"], str) + + if return_type == "base64": + # Check if it's a valid base64 string + assert result["url"].startswith("data:") + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." 
+ pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + - def test_dns_success_response(self) -> None: +class TestSearchAsync: + """Test asynchronous search methods""" - params = { - "url": "https://supabase.com/pricing", - } + @pytest.mark.parametrize( + "test_case", SEARCH_TEST_CASES, ids=[tc["name"] for tc in SEARCH_TEST_CASES] + ) + @pytest.mark.asyncio + async def test_search_async(self, test_case): + """Test asynchronous search with various inputs""" try: - result = client.web.dns(params) - assert result["success"] == True + result = await async_jigsaw.web.search(test_case["params"]) + + assert result["success"] + assert "results" in result + assert isinstance(result["results"], list) + + if test_case["params"].get("max_results"): + assert len(result["results"]) <= test_case["params"]["max_results"] + + # Check result structure + for item in result["results"]: + assert "title" in item + assert "url" in item + assert "description" in item + + # Check AI mode response + if test_case["params"].get("ai_overview"): + assert "ai_overview" in result + + except JigsawStackError as e: + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}") + + +class TestSearchSuggestionsAsync: + """Test asynchronous search suggestions methods""" + + @pytest.mark.parametrize( + "test_case", + SEARCH_SUGGESTIONS_TEST_CASES, + ids=[tc["name"] for tc in SEARCH_SUGGESTIONS_TEST_CASES], + ) + @pytest.mark.asyncio + async def test_search_suggestions_async(self, test_case): + """Test asynchronous search suggestions with various inputs""" + try: + result = await async_jigsaw.web.search_suggestions(test_case["params"]) + + assert result["success"] + assert "suggestions" in result + assert isinstance(result["suggestions"], list) + assert len(result["suggestions"]) > 0 + except JigsawStackError as e: - assert e.message == "Failed to parse API response. Please try again." 
\ No newline at end of file + pytest.fail(f"Unexpected JigsawStackError in {test_case['name']}: {e}")