Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,11 @@ test.py
test_web.py

.eggs/
.conda/
.conda/

main.py
.python-version
pyproject.toml
uv.lock

.ruff_cache/
14 changes: 14 additions & 0 deletions jigsawstack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .embedding import Embedding, AsyncEmbedding
from .exceptions import JigsawStackError
from .image_generation import ImageGeneration, AsyncImageGeneration
from .classification import Classification, AsyncClassification


class JigsawStack:
Expand All @@ -25,6 +26,7 @@ class JigsawStack:
web: Web
search: Search
prompt_engine: PromptEngine
classification: Classification
api_key: str
api_url: str
disable_request_logging: bool
Expand Down Expand Up @@ -118,6 +120,12 @@ def __init__(
disable_request_logging=disable_request_logging,
).image_generation

self.classification = Classification(
api_key=api_key,
api_url=api_url,
disable_request_logging=disable_request_logging,
)



class AsyncJigsawStack:
Expand Down Expand Up @@ -229,6 +237,12 @@ def __init__(
disable_request_logging=disable_request_logging,
).image_generation

self.classification = AsyncClassification(
api_key=api_key,
api_url=api_url,
disable_request_logging=disable_request_logging,
)



# Create a global instance of the Web class
Expand Down
180 changes: 180 additions & 0 deletions jigsawstack/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from typing import Any, Dict, List, Union, cast
from typing_extensions import NotRequired, TypedDict, Literal
from .request import Request, RequestConfig
from .async_request import AsyncRequest, AsyncRequestConfig
from ._config import ClientConfig


class DatasetItemText(TypedDict):
type: Literal["text"]
"""
Type of the dataset item: text
"""

value: str
"""
Value of the dataset item
"""


class DatasetItemImage(TypedDict):
type: Literal["image"]
"""
Type of the dataset item: image
"""

value: str
"""
Value of the dataset item
"""


class LabelItemText(TypedDict):
key: NotRequired[str]
"""
Optional key for the label
"""

type: Literal["text"]
"""
Type of the label: text
"""

value: str
"""
Value of the label
"""


class LabelItemImage(TypedDict):
key: NotRequired[str]
"""
Optional key for the label
"""

type: Literal["image", "text"]
"""
Type of the label: image or text
"""

value: str
"""
Value of the label
"""


class ClassificationTextParams(TypedDict):
dataset: List[DatasetItemText]
"""
List of text dataset items to classify
"""

labels: List[LabelItemText]
"""
List of text labels for classification
"""

multiple_labels: NotRequired[bool]
"""
Whether to allow multiple labels per item
"""


class ClassificationImageParams(TypedDict):
dataset: List[DatasetItemImage]
"""
List of image dataset items to classify
"""

labels: List[LabelItemImage]
"""
List of labels for classification
"""

multiple_labels: NotRequired[bool]
"""
Whether to allow multiple labels per item
"""


class ClassificationResponse(TypedDict):
predictions: List[Union[str, List[str]]]
"""
Classification predictions - single labels or multiple labels per item
"""



class Classification(ClientConfig):

config: RequestConfig

def __init__(
self,
api_key: str,
api_url: str,
disable_request_logging: Union[bool, None] = False,
):
super().__init__(api_key, api_url, disable_request_logging)
self.config = RequestConfig(
api_url=api_url,
api_key=api_key,
disable_request_logging=disable_request_logging,
)

def text(self, params: ClassificationTextParams) -> ClassificationResponse:
path = "/classification"
resp = Request(
config=self.config,
path=path,
params=cast(Dict[Any, Any], params),
verb="post",
).perform_with_content()
return resp
def image(self, params: ClassificationImageParams) -> ClassificationResponse:
path = "/classification"
resp = Request(
config=self.config,
path=path,
params=cast(Dict[Any, Any], params),
verb="post",
).perform_with_content()
return resp



class AsyncClassification(ClientConfig):
config: AsyncRequestConfig

def __init__(
self,
api_key: str,
api_url: str,
disable_request_logging: Union[bool, None] = False,
):
super().__init__(api_key, api_url, disable_request_logging)
self.config = AsyncRequestConfig(
api_url=api_url,
api_key=api_key,
disable_request_logging=disable_request_logging,
)

async def text(self, params: ClassificationTextParams) -> ClassificationResponse:
path = "/classification"
resp = await AsyncRequest(
config=self.config,
path=path,
params=cast(Dict[Any, Any], params),
verb="post",
).perform_with_content()
return resp

async def image(self, params: ClassificationImageParams) -> ClassificationResponse:
path = "/classification"
resp = await AsyncRequest(
config=self.config,
path=path,
params=cast(Dict[Any, Any], params),
verb="post",
).perform_with_content()
return resp
90 changes: 90 additions & 0 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from jigsawstack.exceptions import JigsawStackError
from jigsawstack import JigsawStack

import pytest

# flake8: noqa

client = JigsawStack()


@pytest.mark.parametrize("dataset,labels", [
(
[
{"type": "text", "value": "I love programming"},
{"type": "text", "value": "I love reading books"},
{"type": "text", "value": "I love watching movies"},
{"type": "text", "value": "I love playing games"},
],
[
{"type": "text", "value": "programming"},
{"type": "text", "value": "reading"},
{"type": "text", "value": "watching"},
{"type": "text", "value": "playing"},
]
),
(
[
{"type": "text", "value": "This is awesome!"},
{"type": "text", "value": "I hate this product"},
{"type": "text", "value": "It's okay, nothing special"},
],
[
{"type": "text", "value": "positive"},
{"type": "text", "value": "negative"},
{"type": "text", "value": "neutral"},
]
),
(
[
{"type": "text", "value": "The weather is sunny today"},
{"type": "text", "value": "It's raining heavily outside"},
{"type": "text", "value": "Snow is falling gently"},
],
[
{"type": "text", "value": "sunny"},
{"type": "text", "value": "rainy"},
{"type": "text", "value": "snowy"},
]
),
])
def test_classification_text_success_response(dataset, labels) -> None:
params = {
"dataset": dataset,
"labels": labels,
}
try:
result = client.classification.text(params)
print(result)
assert result["success"] == True
except JigsawStackError as e:
print(str(e))
assert e.message == "Failed to parse API response. Please try again."


@pytest.mark.parametrize("dataset,labels", [
(
[
{"type": "image", "value": "https://as2.ftcdn.net/v2/jpg/02/24/11/57/1000_F_224115780_2ssvcCoTfQrx68Qsl5NxtVIDFWKtAgq2.jpg"},
{"type": "image", "value": "https://t3.ftcdn.net/jpg/02/95/44/22/240_F_295442295_OXsXOmLmqBUfZreTnGo9PREuAPSLQhff.jpg"},
{"type": "image", "value": "https://as1.ftcdn.net/v2/jpg/05/54/94/46/1000_F_554944613_okdr3fBwcE9kTOgbLp4BrtVi8zcKFWdP.jpg"},
],
[
{"type": "text", "value": "banana"},
{"type": "image", "value": "https://upload.wikimedia.org/wikipedia/commons/8/8a/Banana-Single.jpg"},
{"type": "text", "value": "kisses"},
]
),
])
def test_classification_image_success_response(dataset, labels) -> None:
params = {
"dataset": dataset,
"labels": labels,
}
try:
result = client.classification.image(params)
print(result)
assert result["success"] == True
except JigsawStackError as e:
print(str(e))
assert e.message == "Failed to parse API response. Please try again."
67 changes: 67 additions & 0 deletions tests/test_file_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from unittest.mock import MagicMock
import unittest
from jigsawstack.exceptions import JigsawStackError
from jigsawstack import JigsawStack

import pytest

# flake8: noqa

client = JigsawStack()


@pytest.mark.skip(reason="Skipping TestStoreAPI class for now")
class TestStoreAPI(unittest.TestCase):
def test_upload_success_response(self) -> None:
# Sample file content as bytes
file_content = b"This is a test file content"
options = {
"key": "test-file.txt",
"content_type": "text/plain",
"overwrite": True,
"temp_public_url": True
}
try:
result = client.store.upload(file_content, options)
assert result["success"] == True
except JigsawStackError as e:
assert e.message == "Failed to parse API response. Please try again."

def test_get_success_response(self) -> None:
key = "test-file.txt"
try:
result = client.store.get(key)
# For file retrieval, we expect the actual file content
assert result is not None
except JigsawStackError as e:
assert e.message == "Failed to parse API response. Please try again."

def test_delete_success_response(self) -> None:
key = "test-file.txt"
try:
result = client.store.delete(key)
assert result["success"] == True
except JigsawStackError as e:
assert e.message == "Failed to parse API response. Please try again."

def test_upload_without_options_success_response(self) -> None:
# Test upload without optional parameters
file_content = b"This is another test file content"
try:
result = client.store.upload(file_content)
assert result["success"] == True
except JigsawStackError as e:
assert e.message == "Failed to parse API response. Please try again."

def test_upload_with_partial_options_success_response(self) -> None:
# Test upload with partial options
file_content = b"This is a test file with partial options"
options = {
"key": "partial-test-file.txt",
"overwrite": False
}
try:
result = client.store.upload(file_content, options)
assert result["success"] == True
except JigsawStackError as e:
assert e.message == "Failed to parse API response. Please try again."