From 6c0a7a5fab73245b7e08ddbdbd8327c37b22256f Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Thu, 13 Jun 2024 16:09:32 +0300 Subject: [PATCH 1/5] fix: Chain /studio/v1 --- ai21/clients/studio/ai21_client.py | 4 ++-- tests/unittests/clients/studio/test_ai21_client.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 tests/unittests/clients/studio/test_ai21_client.py diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index c01974cd..5b766ab8 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -43,11 +43,11 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = api_host or env_config.api_host + base_url = api_host or f"{env_config.api_host}/studio/v1" self._http_client = AI21HTTPClient( api_key=api_key or env_config.api_key, - base_url=f"{base_url}/studio/v1", + base_url=base_url, api_version=env_config.api_version, headers=headers, timeout_sec=timeout_sec or env_config.timeout_sec, diff --git a/tests/unittests/clients/studio/test_ai21_client.py b/tests/unittests/clients/studio/test_ai21_client.py new file mode 100644 index 00000000..865c5993 --- /dev/null +++ b/tests/unittests/clients/studio/test_ai21_client.py @@ -0,0 +1,12 @@ +from ai21 import AI21Client, AI21EnvConfig + + +def test_ai21_client__when_pass_api_host__should_leave_as_is(): + base_url = "https://dont-modify-me.com" + client = AI21Client(api_host=base_url) + assert client._http_client._base_url == base_url + + +def test_ai21_client__when_not_pass_api_host__should_add_suffix(): + client = AI21Client() + assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1" From 55aa27256be876437dd20963418263d9d65636e7 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Thu, 13 Jun 2024 16:12:13 +0300 Subject: [PATCH 2/5] fix: Async Client --- ai21/clients/studio/async_ai21_client.py | 4 ++-- tests/unittests/clients/studio/test_ai21_client.py | 13 ++++++++----- .../clients/studio/test_async_ai21_client.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 7 deletions(-) create mode 100644 tests/unittests/clients/studio/test_async_ai21_client.py diff --git a/ai21/clients/studio/async_ai21_client.py b/ai21/clients/studio/async_ai21_client.py index f46a16be..b59818a8 100644 --- a/ai21/clients/studio/async_ai21_client.py +++ b/ai21/clients/studio/async_ai21_client.py @@ -36,11 +36,11 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = api_host or env_config.api_host + base_url = api_host or f"{env_config.api_host}/studio/v1" self._http_client = AsyncAI21HTTPClient( api_key=api_key or env_config.api_key, - base_url=f"{base_url}/studio/v1", + base_url=base_url, api_version=env_config.api_version, headers=headers, timeout_sec=timeout_sec or env_config.timeout_sec, diff --git a/tests/unittests/clients/studio/test_ai21_client.py b/tests/unittests/clients/studio/test_ai21_client.py index 865c5993..5c42f12a 100644 --- a/tests/unittests/clients/studio/test_ai21_client.py +++ b/tests/unittests/clients/studio/test_ai21_client.py @@ -1,12 +1,15 @@ -from ai21 import AI21Client, AI21EnvConfig +import pytest +from ai21 import AsyncAI21Client, AI21EnvConfig -def test_ai21_client__when_pass_api_host__should_leave_as_is(): +@pytest.mark.asyncio +def test_async_ai21_client__when_pass_api_host__should_leave_as_is(): base_url = "https://dont-modify-me.com" - client = AI21Client(api_host=base_url) + client = AsyncAI21Client(api_host=base_url) assert client._http_client._base_url == base_url -def test_ai21_client__when_not_pass_api_host__should_add_suffix(): - client = AI21Client() +@pytest.mark.asyncio +def test_async_ai21_client__when_not_pass_api_host__should_add_suffix(): + client = AsyncAI21Client() assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1" diff --git a/tests/unittests/clients/studio/test_async_ai21_client.py b/tests/unittests/clients/studio/test_async_ai21_client.py new file mode 100644 index 00000000..865c5993 --- /dev/null +++ b/tests/unittests/clients/studio/test_async_ai21_client.py @@ -0,0 +1,12 @@ +from ai21 import AI21Client, AI21EnvConfig + + +def test_ai21_client__when_pass_api_host__should_leave_as_is(): + base_url = "https://dont-modify-me.com" + client = AI21Client(api_host=base_url) + assert client._http_client._base_url == base_url + + +def test_ai21_client__when_not_pass_api_host__should_add_suffix(): + client = AI21Client() + assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1" From 3fc726536a103795745e2a95a4de9d0c8e439468 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Thu, 13 Jun 2024 16:38:22 +0300 Subject: [PATCH 3/5] fix: Support urls --- ai21/clients/studio/ai21_client.py | 10 +++++++++- ai21/clients/studio/async_ai21_client.py | 10 +++++++++- ai21/constants.py | 1 + tests/unittests/clients/studio/test_ai21_client.py | 14 ++++++++++++++ .../clients/studio/test_async_ai21_client.py | 12 ++++++++++++ 5 files changed, 45 insertions(+), 2 deletions(-) diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 5b766ab8..1930bbb3 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -43,7 +43,7 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = api_host or f"{env_config.api_host}/studio/v1" + base_url = self._create_url(api_host or env_config.api_host) self._http_client = AI21HTTPClient( api_key=api_key or env_config.api_key, @@ -70,6 +70,14 @@ def __init__( self.segmentation = StudioSegmentation(self._http_client) self.beta = Beta(self._http_client) + def _create_url(self, base_url: str) -> str: + allowed_urls = ["https://api-stage.ai21.com", "https://api.ai21.com"] + + if base_url in allowed_urls: + return f"{base_url}/studio/v1" + + return base_url + def count_tokens(self, text: str, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> int: warnings.warn( "Please use the global get_tokenizer() method directly instead of the AI21Client().count_tokens() method.", diff --git a/ai21/clients/studio/async_ai21_client.py b/ai21/clients/studio/async_ai21_client.py index b59818a8..c67c7ee5 100644 --- a/ai21/clients/studio/async_ai21_client.py +++ b/ai21/clients/studio/async_ai21_client.py @@ -36,7 +36,7 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = api_host or f"{env_config.api_host}/studio/v1" + base_url = self._create_url(api_host or env_config.api_host) self._http_client = AsyncAI21HTTPClient( api_key=api_key or env_config.api_key, @@ -63,3 +63,11 @@ def __init__( self.library = AsyncStudioLibrary(self._http_client) self.segmentation = AsyncStudioSegmentation(self._http_client) self.beta = AsyncBeta(self._http_client) + + def _create_url(self, base_url: str) -> str: + allowed_urls = ["https://api-stage.ai21.com", "https://api.ai21.com"] + + if base_url in allowed_urls: + return f"{base_url}/studio/v1" + + return base_url diff --git a/ai21/constants.py b/ai21/constants.py index 49e416f8..5d97b082 100644 --- a/ai21/constants.py +++ b/ai21/constants.py @@ -1,2 +1,3 @@ DEFAULT_API_VERSION = "v1" STUDIO_HOST = "https://api.ai21.com" +"https://api-stage.ai21.com" diff --git a/tests/unittests/clients/studio/test_ai21_client.py b/tests/unittests/clients/studio/test_ai21_client.py index 5c42f12a..15f55b58 100644 --- a/tests/unittests/clients/studio/test_ai21_client.py +++ b/tests/unittests/clients/studio/test_ai21_client.py @@ -13,3 +13,17 @@ def test_async_ai21_client__when_pass_api_host__should_leave_as_is(): def test_async_ai21_client__when_not_pass_api_host__should_add_suffix(): client = AsyncAI21Client() assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1" + + +@pytest.mark.asyncio +def test_async_ai21_client__when_pass_ai21_api_host__should_add_suffix(): + ai21_url = "https://api.ai21.com" + client = AsyncAI21Client(api_host=ai21_url) + assert client._http_client._base_url == f"{ai21_url}/studio/v1" + + +@pytest.mark.asyncio +def test_async_ai21_client__when_pass_ai21_with_suffix__should_not_modify(): + ai21_url = "https://api.ai21.com/studio/v1" + client = AsyncAI21Client(api_host=ai21_url) + assert client._http_client._base_url == ai21_url diff --git a/tests/unittests/clients/studio/test_async_ai21_client.py b/tests/unittests/clients/studio/test_async_ai21_client.py index 865c5993..0c741489 100644 --- a/tests/unittests/clients/studio/test_async_ai21_client.py +++ b/tests/unittests/clients/studio/test_async_ai21_client.py @@ -10,3 +10,15 @@ def test_ai21_client__when_pass_api_host__should_leave_as_is(): def test_ai21_client__when_not_pass_api_host__should_add_suffix(): client = AI21Client() assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1" + + +def test_ai21_client__when_pass_ai21_api_host__should_add_suffix(): + ai21_url = "https://api.ai21.com" + client = AI21Client(api_host=ai21_url) + assert client._http_client._base_url == f"{ai21_url}/studio/v1" + + +def test_ai21_client__when_pass_ai21_with_suffix__should_not_modify(): + ai21_url = "https://api.ai21.com/studio/v1" + client = AI21Client(api_host=ai21_url) + assert client._http_client._base_url == ai21_url From fe104484810ca9137427d21295e2d4c6ffcf30fb Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Thu, 13 Jun 2024 16:40:06 +0300 Subject: [PATCH 4/5] fix: Extra line --- ai21/clients/studio/ai21_client.py | 3 ++- ai21/clients/studio/async_ai21_client.py | 3 ++- ai21/constants.py | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 1930bbb3..b6b9e2d9 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -19,6 +19,7 @@ from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation from ai21.clients.studio.resources.studio_summarize import StudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment +from ai21.constants import STUDIO_HOST from ai21.http_client.http_client import HttpClient from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer from ai21.tokenizers.factory import get_tokenizer @@ -71,7 +72,7 @@ def __init__( self.beta = Beta(self._http_client) def _create_url(self, base_url: str) -> str: - allowed_urls = ["https://api-stage.ai21.com", "https://api.ai21.com"] + allowed_urls = ["https://api-stage.ai21.com", STUDIO_HOST] if base_url in allowed_urls: return f"{base_url}/studio/v1" diff --git a/ai21/clients/studio/async_ai21_client.py b/ai21/clients/studio/async_ai21_client.py index c67c7ee5..6e389435 100644 --- a/ai21/clients/studio/async_ai21_client.py +++ b/ai21/clients/studio/async_ai21_client.py @@ -2,6 +2,7 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.ai21_http_client.async_ai21_http_client import AsyncAI21HTTPClient +from ai21.constants import STUDIO_HOST from ai21.http_client.async_http_client import AsyncHttpClient from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer from ai21.clients.studio.resources.studio_chat import AsyncStudioChat @@ -65,7 +66,7 @@ def __init__( self.beta = AsyncBeta(self._http_client) def _create_url(self, base_url: str) -> str: - allowed_urls = ["https://api-stage.ai21.com", "https://api.ai21.com"] + allowed_urls = ["https://api-stage.ai21.com", STUDIO_HOST] if base_url in allowed_urls: return f"{base_url}/studio/v1" diff --git a/ai21/constants.py b/ai21/constants.py index 5d97b082..49e416f8 100644 --- a/ai21/constants.py +++ b/ai21/constants.py @@ -1,3 +1,2 @@ DEFAULT_API_VERSION = "v1" STUDIO_HOST = "https://api.ai21.com" -"https://api-stage.ai21.com" From 443a8aaa4f795c80e9b0f73c05fc3bed41815cbd Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Thu, 13 Jun 2024 16:41:21 +0300 Subject: [PATCH 5/5] fix: Moved to external function --- ai21/clients/studio/ai21_client.py | 12 ++---------- ai21/clients/studio/async_ai21_client.py | 16 ++++------------ ai21/clients/studio/client_url_parser.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 22 deletions(-) create mode 100644 ai21/clients/studio/client_url_parser.py diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index b6b9e2d9..0d588f46 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -5,6 +5,7 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.ai21_http_client.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.client_url_parser import create_client_url from ai21.clients.studio.resources.beta.beta import Beta from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat @@ -19,7 +20,6 @@ from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation from ai21.clients.studio.resources.studio_summarize import StudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment -from ai21.constants import STUDIO_HOST from ai21.http_client.http_client import HttpClient from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer from ai21.tokenizers.factory import get_tokenizer @@ -44,7 +44,7 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = self._create_url(api_host or env_config.api_host) + base_url = create_client_url(api_host or env_config.api_host) self._http_client = AI21HTTPClient( api_key=api_key or env_config.api_key, @@ -71,14 +71,6 @@ def __init__( self.segmentation = StudioSegmentation(self._http_client) self.beta = Beta(self._http_client) - def _create_url(self, base_url: str) -> str: - allowed_urls = ["https://api-stage.ai21.com", STUDIO_HOST] - - if base_url in allowed_urls: - return f"{base_url}/studio/v1" - - return base_url - def count_tokens(self, text: str, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> int: warnings.warn( "Please use the global get_tokenizer() method directly instead of the AI21Client().count_tokens() method.", diff --git a/ai21/clients/studio/async_ai21_client.py b/ai21/clients/studio/async_ai21_client.py index 6e389435..5ef59ce2 100644 --- a/ai21/clients/studio/async_ai21_client.py +++ b/ai21/clients/studio/async_ai21_client.py @@ -2,8 +2,8 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.ai21_http_client.async_ai21_http_client import AsyncAI21HTTPClient -from ai21.constants import STUDIO_HOST -from ai21.http_client.async_http_client import AsyncHttpClient +from ai21.clients.studio.client_url_parser import create_client_url +from ai21.clients.studio.resources.beta.async_beta import AsyncBeta from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer from ai21.clients.studio.resources.studio_chat import AsyncStudioChat from ai21.clients.studio.resources.studio_completion import AsyncStudioCompletion @@ -17,7 +17,7 @@ from ai21.clients.studio.resources.studio_segmentation import AsyncStudioSegmentation from ai21.clients.studio.resources.studio_summarize import AsyncStudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import AsyncStudioSummarizeBySegment -from ai21.clients.studio.resources.beta.async_beta import AsyncBeta +from ai21.http_client.async_http_client import AsyncHttpClient class AsyncAI21Client: @@ -37,7 +37,7 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): - base_url = self._create_url(api_host or env_config.api_host) + base_url = create_client_url(api_host or env_config.api_host) self._http_client = AsyncAI21HTTPClient( api_key=api_key or env_config.api_key, @@ -64,11 +64,3 @@ def __init__( self.library = AsyncStudioLibrary(self._http_client) self.segmentation = AsyncStudioSegmentation(self._http_client) self.beta = AsyncBeta(self._http_client) - - def _create_url(self, base_url: str) -> str: - allowed_urls = ["https://api-stage.ai21.com", STUDIO_HOST] - - if base_url in allowed_urls: - return f"{base_url}/studio/v1" - - return base_url diff --git a/ai21/clients/studio/client_url_parser.py b/ai21/clients/studio/client_url_parser.py new file mode 100644 index 00000000..921e3c1a --- /dev/null +++ b/ai21/clients/studio/client_url_parser.py @@ -0,0 +1,10 @@ +from ai21.constants import STUDIO_HOST + + +def create_client_url(base_url: str) -> str: + allowed_urls = ["https://api-stage.ai21.com", STUDIO_HOST] + + if base_url in allowed_urls: + return f"{base_url}/studio/v1" + + return base_url