diff --git a/.gitignore b/.gitignore index bd39df3..7e9271c 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,11 @@ test.py test_web.py .eggs/ -.conda/ \ No newline at end of file +.conda/ + +main.py +.python-version +pyproject.toml +uv.lock + +.ruff_cache/ \ No newline at end of file diff --git a/jigsawstack/__init__.py b/jigsawstack/__init__.py index bdb102e..8dafca5 100644 --- a/jigsawstack/__init__.py +++ b/jigsawstack/__init__.py @@ -15,6 +15,7 @@ from .embedding import Embedding, AsyncEmbedding from .exceptions import JigsawStackError from .image_generation import ImageGeneration, AsyncImageGeneration +from .classification import Classification, AsyncClassification class JigsawStack: @@ -25,6 +26,7 @@ class JigsawStack: web: Web search: Search prompt_engine: PromptEngine + classification: Classification api_key: str api_url: str disable_request_logging: bool @@ -118,6 +120,12 @@ def __init__( disable_request_logging=disable_request_logging, ).image_generation + self.classification = Classification( + api_key=api_key, + api_url=api_url, + disable_request_logging=disable_request_logging, + ) + class AsyncJigsawStack: @@ -229,6 +237,12 @@ def __init__( disable_request_logging=disable_request_logging, ).image_generation + self.classification = AsyncClassification( + api_key=api_key, + api_url=api_url, + disable_request_logging=disable_request_logging, + ) + # Create a global instance of the Web class diff --git a/jigsawstack/classification.py b/jigsawstack/classification.py new file mode 100644 index 0000000..69ed199 --- /dev/null +++ b/jigsawstack/classification.py @@ -0,0 +1,180 @@ +from typing import Any, Dict, List, Union, cast +from typing_extensions import NotRequired, TypedDict, Literal +from .request import Request, RequestConfig +from .async_request import AsyncRequest, AsyncRequestConfig +from ._config import ClientConfig + + +class DatasetItemText(TypedDict): + type: Literal["text"] + """ + Type of the dataset item: text + """ + + value: str + """ + Value of the dataset item + """ + + +class DatasetItemImage(TypedDict): + type: Literal["image"] + """ + Type of the dataset item: image + """ + + value: str + """ + Value of the dataset item + """ + + +class LabelItemText(TypedDict): + key: NotRequired[str] + """ + Optional key for the label + """ + + type: Literal["text"] + """ + Type of the label: text + """ + + value: str + """ + Value of the label + """ + + +class LabelItemImage(TypedDict): + key: NotRequired[str] + """ + Optional key for the label + """ + + type: Literal["image", "text"] + """ + Type of the label: image or text + """ + + value: str + """ + Value of the label + """ + + +class ClassificationTextParams(TypedDict): + dataset: List[DatasetItemText] + """ + List of text dataset items to classify + """ + + labels: List[LabelItemText] + """ + List of text labels for classification + """ + + multiple_labels: NotRequired[bool] + """ + Whether to allow multiple labels per item + """ + + +class ClassificationImageParams(TypedDict): + dataset: List[DatasetItemImage] + """ + List of image dataset items to classify + """ + + labels: List[LabelItemImage] + """ + List of labels for classification + """ + + multiple_labels: NotRequired[bool] + """ + Whether to allow multiple labels per item + """ + + +class ClassificationResponse(TypedDict): + predictions: List[Union[str, List[str]]] + """ + Classification predictions - single labels or multiple labels per item + """ + + + +class Classification(ClientConfig): + + config: RequestConfig + + def __init__( + self, + api_key: str, + api_url: str, + disable_request_logging: Union[bool, None] = False, + ): + super().__init__(api_key, api_url, disable_request_logging) + self.config = RequestConfig( + api_url=api_url, + api_key=api_key, + disable_request_logging=disable_request_logging, + ) + + def text(self, params: ClassificationTextParams) -> ClassificationResponse: + path = "/classification" + resp = Request( + config=self.config, + path=path, + params=cast(Dict[Any, Any], params), + verb="post", + ).perform_with_content() + return resp + def image(self, params: ClassificationImageParams) -> ClassificationResponse: + path = "/classification" + resp = Request( + config=self.config, + path=path, + params=cast(Dict[Any, Any], params), + verb="post", + ).perform_with_content() + return resp + + + +class AsyncClassification(ClientConfig): + config: AsyncRequestConfig + + def __init__( + self, + api_key: str, + api_url: str, + disable_request_logging: Union[bool, None] = False, + ): + super().__init__(api_key, api_url, disable_request_logging) + self.config = AsyncRequestConfig( + api_url=api_url, + api_key=api_key, + disable_request_logging=disable_request_logging, + ) + + async def text(self, params: ClassificationTextParams) -> ClassificationResponse: + path = "/classification" + resp = await AsyncRequest( + config=self.config, + path=path, + params=cast(Dict[Any, Any], params), + verb="post", + ).perform_with_content() + return resp + + async def image(self, params: ClassificationImageParams) -> ClassificationResponse: + path = "/classification" + resp = await AsyncRequest( + config=self.config, + path=path, + params=cast(Dict[Any, Any], params), + verb="post", + ).perform_with_content() + return resp \ No newline at end of file diff --git a/tests/test_classification.py b/tests/test_classification.py new file mode 100644 index 0000000..1f247c5 --- /dev/null +++ b/tests/test_classification.py @@ -0,0 +1,90 @@ +from jigsawstack.exceptions import JigsawStackError +from jigsawstack import JigsawStack + +import pytest + +# flake8: noqa + +client = JigsawStack() + + +@pytest.mark.parametrize("dataset,labels", [ + ( + [ + {"type": "text", "value": "I love programming"}, + {"type": "text", "value": "I love reading books"}, + {"type": "text", "value": "I love watching movies"}, + {"type": "text", "value": "I love playing games"}, + ], + [ + {"type": "text", "value": "programming"}, + {"type": "text", "value": "reading"}, + {"type": "text", "value": "watching"}, + {"type": "text", "value": "playing"}, + ] + ), + ( + [ + {"type": "text", "value": "This is awesome!"}, + {"type": "text", "value": "I hate this product"}, + {"type": "text", "value": "It's okay, nothing special"}, + ], + [ + {"type": "text", "value": "positive"}, + {"type": "text", "value": "negative"}, + {"type": "text", "value": "neutral"}, + ] + ), + ( + [ + {"type": "text", "value": "The weather is sunny today"}, + {"type": "text", "value": "It's raining heavily outside"}, + {"type": "text", "value": "Snow is falling gently"}, + ], + [ + {"type": "text", "value": "sunny"}, + {"type": "text", "value": "rainy"}, + {"type": "text", "value": "snowy"}, + ] + ), +]) +def test_classification_text_success_response(dataset, labels) -> None: + params = { + "dataset": dataset, + "labels": labels, + } + try: + result = client.classification.text(params) + print(result) + assert result["success"] == True + except JigsawStackError as e: + print(str(e)) + assert e.message == "Failed to parse API response. Please try again." + + +@pytest.mark.parametrize("dataset,labels", [ + ( + [ + {"type": "image", "value": "https://as2.ftcdn.net/v2/jpg/02/24/11/57/1000_F_224115780_2ssvcCoTfQrx68Qsl5NxtVIDFWKtAgq2.jpg"}, + {"type": "image", "value": "https://t3.ftcdn.net/jpg/02/95/44/22/240_F_295442295_OXsXOmLmqBUfZreTnGo9PREuAPSLQhff.jpg"}, + {"type": "image", "value": "https://as1.ftcdn.net/v2/jpg/05/54/94/46/1000_F_554944613_okdr3fBwcE9kTOgbLp4BrtVi8zcKFWdP.jpg"}, + ], + [ + {"type": "text", "value": "banana"}, + {"type": "image", "value": "https://upload.wikimedia.org/wikipedia/commons/8/8a/Banana-Single.jpg"}, + {"type": "text", "value": "kisses"}, + ] + ), +]) +def test_classification_image_success_response(dataset, labels) -> None: + params = { + "dataset": dataset, + "labels": labels, + } + try: + result = client.classification.image(params) + print(result) + assert result["success"] == True + except JigsawStackError as e: + print(str(e)) + assert e.message == "Failed to parse API response. Please try again." diff --git a/tests/test_file_store.py b/tests/test_file_store.py new file mode 100644 index 0000000..3f346d9 --- /dev/null +++ b/tests/test_file_store.py @@ -0,0 +1,67 @@ +from unittest.mock import MagicMock +import unittest +from jigsawstack.exceptions import JigsawStackError +from jigsawstack import JigsawStack + +import pytest + +# flake8: noqa + +client = JigsawStack() + + +@pytest.mark.skip(reason="Skipping TestStoreAPI class for now") +class TestStoreAPI(unittest.TestCase): + def test_upload_success_response(self) -> None: + # Sample file content as bytes + file_content = b"This is a test file content" + options = { + "key": "test-file.txt", + "content_type": "text/plain", + "overwrite": True, + "temp_public_url": True + } + try: + result = client.store.upload(file_content, options) + assert result["success"] == True + except JigsawStackError as e: + assert e.message == "Failed to parse API response. Please try again." + + def test_get_success_response(self) -> None: + key = "test-file.txt" + try: + result = client.store.get(key) + # For file retrieval, we expect the actual file content + assert result is not None + except JigsawStackError as e: + assert e.message == "Failed to parse API response. Please try again." + + def test_delete_success_response(self) -> None: + key = "test-file.txt" + try: + result = client.store.delete(key) + assert result["success"] == True + except JigsawStackError as e: + assert e.message == "Failed to parse API response. Please try again." + + def test_upload_without_options_success_response(self) -> None: + # Test upload without optional parameters + file_content = b"This is another test file content" + try: + result = client.store.upload(file_content) + assert result["success"] == True + except JigsawStackError as e: + assert e.message == "Failed to parse API response. Please try again." + + def test_upload_with_partial_options_success_response(self) -> None: + # Test upload with partial options + file_content = b"This is a test file with partial options" + options = { + "key": "partial-test-file.txt", + "overwrite": False + } + try: + result = client.store.upload(file_content, options) + assert result["success"] == True + except JigsawStackError as e: + assert e.message == "Failed to parse API response. Please try again."