diff --git a/_test_unstructured_client/test__decorators.py b/_test_unstructured_client/test__decorators.py
new file mode 100644
index 00000000..e8781542
--- /dev/null
+++ b/_test_unstructured_client/test__decorators.py
@@ -0,0 +1,110 @@
+import pytest
+
+from unstructured_client import UnstructuredClient
+from unstructured_client.models import shared
+from unstructured_client.models.errors import SDKError
+
+
+FAKE_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+
+
+@pytest.mark.parametrize(
+    "server_url",
+    [
+        # -- well-formed url --
+        "https://unstructured-000mock.api.unstructuredapp.io",
+        # -- common malformed urls --
+        "unstructured-000mock.api.unstructuredapp.io",
+        "http://unstructured-000mock.api.unstructuredapp.io/general/v0/general",
+        "https://unstructured-000mock.api.unstructuredapp.io/general/v0/general",
+        "unstructured-000mock.api.unstructuredapp.io/general/v0/general",
+    ],
+)
+def test_clean_server_url_fixes_malformed_paid_api_url(server_url: str):
+    client = UnstructuredClient(
+        server_url=server_url,
+        api_key_auth=FAKE_KEY,
+    )
+    assert (
+        client.general.sdk_configuration.server_url
+        == "https://unstructured-000mock.api.unstructuredapp.io"
+    )
+
+
+@pytest.mark.parametrize(
+    "server_url",
+    [
+        # -- well-formed url --
+        "http://localhost:8000",
+        # -- common malformed urls --
+        "localhost:8000",
+        "localhost:8000/general/v0/general",
+        "http://localhost:8000/general/v0/general",
+    ],
+)
+def test_clean_server_url_fixes_malformed_localhost_url(server_url: str):
+    client = UnstructuredClient(
+        server_url=server_url,
+        api_key_auth=FAKE_KEY,
+    )
+    assert client.general.sdk_configuration.server_url == "http://localhost:8000"
+
+
+def test_clean_server_url_returns_empty_string_given_empty_string():
+    client = UnstructuredClient( server_url="", api_key_auth=FAKE_KEY)
+    assert client.general.sdk_configuration.server_url == ""
+
+
+def test_clean_server_url_returns_None_given_no_server_url():
+    client = UnstructuredClient(
+        api_key_auth=FAKE_KEY,
+    )
+    assert client.general.sdk_configuration.server_url is None
+
+
+@pytest.mark.parametrize(
+    "server_url",
+    [
+        # -- well-formed url --
+        "https://unstructured-000mock.api.unstructuredapp.io",
+        # -- malformed url --
+        "unstructured-000mock.api.unstructuredapp.io/general/v0/general",
+    ],
+)
+def test_clean_server_url_fixes_malformed_urls_with_positional_arguments(
+    server_url: str,
+):
+    client = UnstructuredClient(
+        FAKE_KEY,
+        "",
+        server_url,
+    )
+    assert (
+        client.general.sdk_configuration.server_url
+        == "https://unstructured-000mock.api.unstructuredapp.io"
+    )
+
+
+def test_suggest_defining_url_issues_a_warning_on_a_401():
+    client = UnstructuredClient(
+        api_key_auth=FAKE_KEY,
+    )
+
+    filename = "_sample_docs/layout-parser-paper-fast.pdf"
+
+    with open(filename, "rb") as f:
+        files = shared.Files(
+            content=f.read(),
+            file_name=filename,
+        )
+
+    req = shared.PartitionParameters(
+        files=files,
+    )
+
+    with pytest.warns(
+        UserWarning,
+        match="If intending to use the paid API, please define `server_url` in your request.",
+    ):
+        with pytest.raises(SDKError, match="API error occurred: Status 401"):
+            client.general.partition(req)
diff --git a/_test_unstructured_client/test_check_url_protocol.py b/_test_unstructured_client/test_check_url_protocol.py
deleted file mode 100644
index 2bc0b9fd..00000000
--- a/_test_unstructured_client/test_check_url_protocol.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import os
-import pytest
-
-from unstructured_client import UnstructuredClient
-
-
-def get_api_key():
-    api_key = os.getenv("UNS_API_KEY")
-    if api_key is None:
-        raise ValueError("""UNS_API_KEY environment variable not set. 
-Set it in your current shell session with `export UNS_API_KEY=`""")
-    return api_key
-
-
-@pytest.mark.parametrize(
-    ("server_url"),
-    [
-        ("https://unstructured-000mock.api.unstructuredapp.io"),  # correct url
-        ("unstructured-000mock.api.unstructuredapp.io"),
-        ("http://unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
-        ("https://unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
-        ("unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
-    ]
-)
-def test_clean_server_url_on_paid_api_url(server_url: str):
-    client = UnstructuredClient(
-        server_url=server_url,
-        api_key_auth=get_api_key(),
-    )
-    assert client.general.sdk_configuration.server_url == "https://unstructured-000mock.api.unstructuredapp.io"
-
-
-@pytest.mark.parametrize(
-    ("server_url"),
-    [
-        ("http://localhost:8000"),  # correct url
-        ("localhost:8000"),
-        ("localhost:8000/general/v0/general"),
-        ("http://localhost:8000/general/v0/general"),
-    ]
-)
-def test_clean_server_url_on_localhost(server_url: str):
-    client = UnstructuredClient(
-        server_url=server_url,
-        api_key_auth=get_api_key(),
-    )
-    assert client.general.sdk_configuration.server_url == "http://localhost:8000"
-
-
-def test_clean_server_url_on_empty_string():
-    client = UnstructuredClient(
-        server_url="",
-        api_key_auth=get_api_key(),
-    )
-    assert client.general.sdk_configuration.server_url == ""
-
-@pytest.mark.parametrize(
-    ("server_url"),
-    [
-        ("https://unstructured-000mock.api.unstructuredapp.io"),
-        ("unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
-    ]
-)
-def test_clean_server_url_with_positional_arguments(server_url: str):
-    client = UnstructuredClient(
-        get_api_key(),
-        "",
-        server_url,
-    )
-    assert client.general.sdk_configuration.server_url == "https://unstructured-000mock.api.unstructuredapp.io"
diff --git a/src/unstructured_client/general.py b/src/unstructured_client/general.py
index cbd14b1c..ccaa85b7 100644
--- a/src/unstructured_client/general.py
+++ b/src/unstructured_client/general.py
@@ -4,6 +4,8 @@
 from typing import Any, List, Optional
 from unstructured_client import utils
 from unstructured_client.models import errors, operations, shared
+from unstructured_client.utils._decorators import suggest_defining_url_if_401  # human code
+
 
 class General:
     sdk_configuration: SDKConfiguration
@@ -12,7 +14,7 @@ def __init__(self, sdk_config: SDKConfiguration) -> None:
         self.sdk_configuration = sdk_config
         
     
-    
+    @suggest_defining_url_if_401  # human code
     def partition(self, request: Optional[shared.PartitionParameters], retries: Optional[utils.RetryConfig] = None) -> operations.PartitionResponse:
         r"""Pipeline 1"""
         base_url = utils.template_url(*self.sdk_configuration.get_server_details())
diff --git a/src/unstructured_client/sdk.py b/src/unstructured_client/sdk.py
index 9358dbac..ff866c32 100644
--- a/src/unstructured_client/sdk.py
+++ b/src/unstructured_client/sdk.py
@@ -6,7 +6,7 @@
 from typing import Callable, Dict, Union
 from unstructured_client import utils
 from unstructured_client.models import shared
-from unstructured_client.utils._decorators import clean_server_url
+from unstructured_client.utils._decorators import clean_server_url  # human code
 
 class UnstructuredClient:
     r"""Unstructured Pipeline API: Partition documents with the Unstructured library"""
@@ -14,7 +14,7 @@ class UnstructuredClient:
 
     sdk_configuration: SDKConfiguration
 
-    @clean_server_url
+    @clean_server_url  # human code
     def __init__(self,
                  api_key_auth: Union[str, Callable[[], str]],
                  server: str = None,
diff --git a/src/unstructured_client/utils/_decorators.py b/src/unstructured_client/utils/_decorators.py
index fd891c2a..918725ce 100644
--- a/src/unstructured_client/utils/_decorators.py
+++ b/src/unstructured_client/utils/_decorators.py
@@ -1,15 +1,27 @@
 from __future__ import annotations
 
 import functools
-from typing import cast, Callable, Optional
+from typing import cast, Callable, TYPE_CHECKING, Optional
 from typing_extensions import ParamSpec
 from urllib.parse import urlparse, urlunparse, ParseResult
+import warnings
+
+from unstructured_client.models import errors, operations
+
+if TYPE_CHECKING:
+    from unstructured_client.general import General
 
 
 _P = ParamSpec("_P")
 
 
 def clean_server_url(func: Callable[_P, None]) -> Callable[_P, None]:
+    """A decorator for fixing common types of malformed 'server_url' arguments.
+
+    This decorator addresses the common problem of users omitting or using the wrong url scheme
+    and/or adding the '/general/v0/general' path to the 'server_url'. The decorator should be
+    manually applied to the __init__ method of UnstructuredClient after merging a PR from Speakeasy.
+    """
 
     @functools.wraps(func)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
@@ -39,8 +51,40 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
         if url_is_in_kwargs:
             kwargs["server_url"] = urlunparse(cleaned_url)
         else:
-            args = args[:SERVER_URL_ARG_IDX] + (urlunparse(cleaned_url),) + args[SERVER_URL_ARG_IDX+1:] # type: ignore
-        
+            args = (
+                args[:SERVER_URL_ARG_IDX]
+                + (urlunparse(cleaned_url),)
+                + args[SERVER_URL_ARG_IDX + 1 :]
+            )  # type: ignore
+
         return func(*args, **kwargs)
 
     return wrapper
+
+
+def suggest_defining_url_if_401(
+    func: Callable[_P, operations.PartitionResponse]
+) -> Callable[_P, operations.PartitionResponse]:
+    """A decorator to suggest defining the 'server_url' parameter if a 401 Unauthorized error is
+    encountered.
+
+    This decorator addresses the common problem of users not passing in the 'server_url' when
+    using their paid api key. The decorator should be manually applied to General.partition after
+    merging a PR from Speakeasy. The original SDKError is always re-raised to the caller.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> operations.PartitionResponse:
+        try:
+            return func(*args, **kwargs)
+        except errors.SDKError as e:
+            if e.status_code == 401:
+                general_obj: General = args[0]  # type: ignore
+                if not general_obj.sdk_configuration.server_url:
+                    warnings.warn(
+                        "If intending to use the paid API, please define `server_url` in your request."
+                    )
+            # -- re-raise the original error; calling func again would resend the whole request --
+            raise
+
+    return wrapper