Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions _test_unstructured_client/test__decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest

from unstructured_client import UnstructuredClient
from unstructured_client.models import shared
from unstructured_client.models.errors import SDKError


# Dummy API key; the tests in this module pass it as `api_key_auth` when constructing a client.
FAKE_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"


@pytest.mark.parametrize(
    "server_url",
    [
        # -- well-formed url --
        "https://unstructured-000mock.api.unstructuredapp.io",
        # -- common malformed urls --
        "unstructured-000mock.api.unstructuredapp.io",
        "http://unstructured-000mock.api.unstructuredapp.io/general/v0/general",
        "https://unstructured-000mock.api.unstructuredapp.io/general/v0/general",
        "unstructured-000mock.api.unstructuredapp.io/general/v0/general",
    ],
)
def test_clean_server_url_fixes_malformed_paid_api_url(server_url: str):
    """Every paid-API URL variant ends up stored as the bare https base URL."""
    expected_url = "https://unstructured-000mock.api.unstructuredapp.io"

    client = UnstructuredClient(server_url=server_url, api_key_auth=FAKE_KEY)
    stored_url = client.general.sdk_configuration.server_url

    assert stored_url == expected_url


@pytest.mark.parametrize(
    "server_url",
    [
        # -- well-formed url --
        "http://localhost:8000",
        # -- common malformed urls --
        "localhost:8000",
        "localhost:8000/general/v0/general",
        "http://localhost:8000/general/v0/general",
    ],
)
def test_clean_server_url_fixes_malformed_localhost_url(server_url: str):
    """Every localhost URL variant ends up stored as `http://localhost:8000`."""
    expected_url = "http://localhost:8000"

    client = UnstructuredClient(server_url=server_url, api_key_auth=FAKE_KEY)
    stored_url = client.general.sdk_configuration.server_url

    assert stored_url == expected_url


def test_clean_server_url_returns_empty_string_given_empty_string():
    """An empty `server_url` is stored unchanged as an empty string."""
    client = UnstructuredClient(
        server_url="",
        api_key_auth=FAKE_KEY,
    )
    stored_url = client.general.sdk_configuration.server_url
    assert stored_url == ""


def test_clean_server_url_returns_None_given_no_server_url():
    """When no `server_url` is supplied, the stored `server_url` is None."""
    client = UnstructuredClient(
        api_key_auth=FAKE_KEY,
    )
    # Identity check: `== None` can be satisfied by any object with a permissive
    # `__eq__`, whereas `is None` only passes for the None singleton itself.
    assert client.general.sdk_configuration.server_url is None


@pytest.mark.parametrize(
    "server_url",
    [
        # -- well-formed url --
        "https://unstructured-000mock.api.unstructuredapp.io",
        # -- malformed url --
        "unstructured-000mock.api.unstructuredapp.io/general/v0/general",
    ],
)
def test_clean_server_url_fixes_malformed_urls_with_positional_arguments(
    server_url: str,
):
    """The URL is still cleaned when it is passed positionally instead of by keyword."""
    expected_url = "https://unstructured-000mock.api.unstructuredapp.io"

    # Positional order: API key, an empty second argument, then the URL under test.
    positional_args = (FAKE_KEY, "", server_url)
    client = UnstructuredClient(*positional_args)

    stored_url = client.general.sdk_configuration.server_url
    assert stored_url == expected_url


def test_suggest_defining_url_issues_a_warning_on_a_401():
    """A 401 from `partition` with no `server_url` set raises SDKError and emits a UserWarning."""
    sample_path = "_sample_docs/layout-parser-paper-fast.pdf"
    expected_warning = (
        "If intending to use the paid API, please define `server_url` in your request."
    )

    client = UnstructuredClient(api_key_auth=FAKE_KEY)

    with open(sample_path, "rb") as pdf:
        payload = pdf.read()

    request = shared.PartitionParameters(
        files=shared.Files(content=payload, file_name=sample_path),
    )

    # The warning check sits inside the raises check so both are evaluated
    # against the same failing call.
    with pytest.raises(SDKError, match="API error occurred: Status 401"):
        with pytest.warns(UserWarning, match=expected_warning):
            client.general.partition(request)
70 changes: 0 additions & 70 deletions _test_unstructured_client/test_check_url_protocol.py

This file was deleted.

4 changes: 3 additions & 1 deletion src/unstructured_client/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Any, List, Optional
from unstructured_client import utils
from unstructured_client.models import errors, operations, shared
from unstructured_client.utils._decorators import suggest_defining_url_if_401 # human code


class General:
sdk_configuration: SDKConfiguration
Expand All @@ -12,7 +14,7 @@ def __init__(self, sdk_config: SDKConfiguration) -> None:
self.sdk_configuration = sdk_config



@suggest_defining_url_if_401 # human code
def partition(self, request: Optional[shared.PartitionParameters], retries: Optional[utils.RetryConfig] = None) -> operations.PartitionResponse:
r"""Pipeline 1"""
base_url = utils.template_url(*self.sdk_configuration.get_server_details())
Expand Down
4 changes: 2 additions & 2 deletions src/unstructured_client/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from typing import Callable, Dict, Union
from unstructured_client import utils
from unstructured_client.models import shared
from unstructured_client.utils._decorators import clean_server_url
from unstructured_client.utils._decorators import clean_server_url # human code

class UnstructuredClient:
r"""Unstructured Pipeline API: Partition documents with the Unstructured library"""
general: General

sdk_configuration: SDKConfiguration

@clean_server_url
@clean_server_url # human code
def __init__(self,
api_key_auth: Union[str, Callable[[], str]],
server: str = None,
Expand Down
50 changes: 47 additions & 3 deletions src/unstructured_client/utils/_decorators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from __future__ import annotations

import functools
from typing import cast, Callable, Optional
from typing import cast, Callable, TYPE_CHECKING, Optional
from typing_extensions import ParamSpec
from urllib.parse import urlparse, urlunparse, ParseResult
import warnings

from unstructured_client.models import errors, operations

if TYPE_CHECKING:
from unstructured_client.general import General


# ParamSpec used in the decorators' type hints so each wrapper is annotated with
# the same parameters as the callable it wraps.
_P = ParamSpec("_P")


def clean_server_url(func: Callable[_P, None]) -> Callable[_P, None]:
"""A decorator for fixing common types of malformed 'server_url' arguments.

This decorator addresses the common problem of users omitting or using the wrong url scheme
and/or adding the '/general/v0/general' path to the 'server_url'. The decorator should be
manually applied to the __init__ method of UnstructuredClient after merging a PR from Speakeasy.
"""

@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
Expand Down Expand Up @@ -39,8 +51,40 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
if url_is_in_kwargs:
kwargs["server_url"] = urlunparse(cleaned_url)
else:
args = args[:SERVER_URL_ARG_IDX] + (urlunparse(cleaned_url),) + args[SERVER_URL_ARG_IDX+1:] # type: ignore

args = (
args[:SERVER_URL_ARG_IDX]
+ (urlunparse(cleaned_url),)
+ args[SERVER_URL_ARG_IDX + 1 :]
) # type: ignore

return func(*args, **kwargs)

return wrapper


def suggest_defining_url_if_401(
    func: Callable[_P, operations.PartitionResponse]
) -> Callable[_P, operations.PartitionResponse]:
    """A decorator to suggest defining the 'server_url' parameter if a 401 Unauthorized error is
    encountered.

    This decorator addresses the common problem of users not passing in the 'server_url' when
    using their paid api key. The decorator should be manually applied to General.partition after
    merging a PR from Speakeasy.

    The wrapped callable is invoked exactly once. If it raises `errors.SDKError` with status code
    401 and the owning object's `sdk_configuration.server_url` is falsy, a warning is issued.
    The original `SDKError` is then re-raised unchanged, whatever its status code.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> operations.PartitionResponse:
        try:
            return func(*args, **kwargs)
        except errors.SDKError as e:
            if e.status_code == 401:
                # The decorator is applied to a method, so the first positional
                # argument is the instance that owns the SDK configuration.
                general_obj: General = args[0]  # type: ignore
                if not general_obj.sdk_configuration.server_url:
                    warnings.warn(
                        "If intending to use the paid API, please define `server_url` in your request."
                    )
            # Re-raise the caught error rather than calling `func` again: a second
            # call would repeat the whole request just to surface the same failure.
            raise

    return wrapper