diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index c01974cd..0d588f46 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -5,6 +5,7 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.ai21_http_client.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.client_url_parser import create_client_url from ai21.clients.studio.resources.beta.beta import Beta from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat @@ -43,11 +44,11 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = api_host or env_config.api_host + base_url = create_client_url(api_host or env_config.api_host) self._http_client = AI21HTTPClient( api_key=api_key or env_config.api_key, - base_url=f"{base_url}/studio/v1", + base_url=base_url, api_version=env_config.api_version, headers=headers, timeout_sec=timeout_sec or env_config.timeout_sec, diff --git a/ai21/clients/studio/async_ai21_client.py b/ai21/clients/studio/async_ai21_client.py index f46a16be..5ef59ce2 100644 --- a/ai21/clients/studio/async_ai21_client.py +++ b/ai21/clients/studio/async_ai21_client.py @@ -2,7 +2,8 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.ai21_http_client.async_ai21_http_client import AsyncAI21HTTPClient -from ai21.http_client.async_http_client import AsyncHttpClient +from ai21.clients.studio.client_url_parser import create_client_url +from ai21.clients.studio.resources.beta.async_beta import AsyncBeta from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer from ai21.clients.studio.resources.studio_chat import AsyncStudioChat from ai21.clients.studio.resources.studio_completion import AsyncStudioCompletion @@ -16,7 +17,7 @@ from ai21.clients.studio.resources.studio_segmentation import AsyncStudioSegmentation from ai21.clients.studio.resources.studio_summarize import AsyncStudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import AsyncStudioSummarizeBySegment -from ai21.clients.studio.resources.beta.async_beta import AsyncBeta +from ai21.http_client.async_http_client import AsyncHttpClient class AsyncAI21Client: @@ -36,11 +37,11 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = api_host or env_config.api_host + base_url = create_client_url(api_host or env_config.api_host) self._http_client = AsyncAI21HTTPClient( api_key=api_key or env_config.api_key, - base_url=f"{base_url}/studio/v1", + base_url=base_url, api_version=env_config.api_version, headers=headers, timeout_sec=timeout_sec or env_config.timeout_sec, diff --git a/ai21/clients/studio/client_url_parser.py b/ai21/clients/studio/client_url_parser.py new file mode 100644 index 00000000..921e3c1a --- /dev/null +++ b/ai21/clients/studio/client_url_parser.py @@ -0,0 +1,10 @@ +from ai21.constants import STUDIO_HOST + + +def create_client_url(base_url: str) -> str: + allowed_urls = ["https://api-stage.ai21.com", STUDIO_HOST] + + if base_url in allowed_urls: + return f"{base_url}/studio/v1" + + return base_url diff --git a/tests/unittests/clients/studio/test_ai21_client.py b/tests/unittests/clients/studio/test_ai21_client.py new file mode 100644 index 00000000..15f55b58 --- /dev/null +++ b/tests/unittests/clients/studio/test_ai21_client.py @@ -0,0 +1,29 @@ +import pytest +from ai21 import AsyncAI21Client, AI21EnvConfig + + +@pytest.mark.asyncio +def test_async_ai21_client__when_pass_api_host__should_leave_as_is(): + base_url = "https://dont-modify-me.com" + client = AsyncAI21Client(api_host=base_url) + assert client._http_client._base_url == base_url + + +@pytest.mark.asyncio +def test_async_ai21_client__when_not_pass_api_host__should_add_suffix(): + client = AsyncAI21Client() + assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1" + + +@pytest.mark.asyncio +def test_async_ai21_client__when_pass_ai21_api_host__should_add_suffix(): + ai21_url = "https://api.ai21.com" + client = AsyncAI21Client(api_host=ai21_url) + assert client._http_client._base_url == f"{ai21_url}/studio/v1" + + +@pytest.mark.asyncio +def test_async_ai21_client__when_pass_ai21_with_suffix__should_not_modify(): + ai21_url = "https://api.ai21.com/studio/v1" + client = AsyncAI21Client(api_host=ai21_url) + assert client._http_client._base_url == ai21_url diff --git a/tests/unittests/clients/studio/test_async_ai21_client.py b/tests/unittests/clients/studio/test_async_ai21_client.py new file mode 100644 index 00000000..0c741489 --- /dev/null +++ b/tests/unittests/clients/studio/test_async_ai21_client.py @@ -0,0 +1,24 @@ +from ai21 import AI21Client, AI21EnvConfig + + +def test_ai21_client__when_pass_api_host__should_leave_as_is(): + base_url = "https://dont-modify-me.com" + client = AI21Client(api_host=base_url) + assert client._http_client._base_url == base_url + + +def test_ai21_client__when_not_pass_api_host__should_add_suffix(): + client = AI21Client() + assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1" + + +def test_ai21_client__when_pass_ai21_api_host__should_add_suffix(): + ai21_url = "https://api.ai21.com" + client = AI21Client(api_host=ai21_url) + assert client._http_client._base_url == f"{ai21_url}/studio/v1" + + +def test_ai21_client__when_pass_ai21_with_suffix__should_not_modify(): + ai21_url = "https://api.ai21.com/studio/v1" + client = AI21Client(api_host=ai21_url) + assert client._http_client._base_url == ai21_url