Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.ai21_http_client.ai21_http_client import AI21HTTPClient
from ai21.clients.studio.client_url_parser import create_client_url
from ai21.clients.studio.resources.beta.beta import Beta
from ai21.clients.studio.resources.studio_answer import StudioAnswer
from ai21.clients.studio.resources.studio_chat import StudioChat
Expand Down Expand Up @@ -43,11 +44,11 @@ def __init__(
env_config: _AI21EnvConfig = AI21EnvConfig,
**kwargs,
):
base_url = api_host or env_config.api_host
base_url = create_client_url(api_host or env_config.api_host)

self._http_client = AI21HTTPClient(
api_key=api_key or env_config.api_key,
base_url=f"{base_url}/studio/v1",
base_url=base_url,
api_version=env_config.api_version,
headers=headers,
timeout_sec=timeout_sec or env_config.timeout_sec,
Expand Down
9 changes: 5 additions & 4 deletions ai21/clients/studio/async_ai21_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.ai21_http_client.async_ai21_http_client import AsyncAI21HTTPClient
from ai21.http_client.async_http_client import AsyncHttpClient
from ai21.clients.studio.client_url_parser import create_client_url
from ai21.clients.studio.resources.beta.async_beta import AsyncBeta
from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer
from ai21.clients.studio.resources.studio_chat import AsyncStudioChat
from ai21.clients.studio.resources.studio_completion import AsyncStudioCompletion
Expand All @@ -16,7 +17,7 @@
from ai21.clients.studio.resources.studio_segmentation import AsyncStudioSegmentation
from ai21.clients.studio.resources.studio_summarize import AsyncStudioSummarize
from ai21.clients.studio.resources.studio_summarize_by_segment import AsyncStudioSummarizeBySegment
from ai21.clients.studio.resources.beta.async_beta import AsyncBeta
from ai21.http_client.async_http_client import AsyncHttpClient


class AsyncAI21Client:
Expand All @@ -36,11 +37,11 @@ def __init__(
env_config: _AI21EnvConfig = AI21EnvConfig,
**kwargs,
):
base_url = api_host or env_config.api_host
base_url = create_client_url(api_host or env_config.api_host)

self._http_client = AsyncAI21HTTPClient(
api_key=api_key or env_config.api_key,
base_url=f"{base_url}/studio/v1",
base_url=base_url,
api_version=env_config.api_version,
headers=headers,
timeout_sec=timeout_sec or env_config.timeout_sec,
Expand Down
10 changes: 10 additions & 0 deletions ai21/clients/studio/client_url_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ai21.constants import STUDIO_HOST


def create_client_url(base_url: str) -> str:
allowed_urls = ["https://api-stage.ai21.com", STUDIO_HOST]

if base_url in allowed_urls:
return f"{base_url}/studio/v1"

return base_url
29 changes: 29 additions & 0 deletions tests/unittests/clients/studio/test_ai21_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from ai21 import AsyncAI21Client, AI21EnvConfig


@pytest.mark.asyncio
def test_async_ai21_client__when_pass_api_host__should_leave_as_is():
base_url = "https://dont-modify-me.com"
client = AsyncAI21Client(api_host=base_url)
assert client._http_client._base_url == base_url


@pytest.mark.asyncio
def test_async_ai21_client__when_not_pass_api_host__should_add_suffix():
client = AsyncAI21Client()
assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1"


@pytest.mark.asyncio
def test_async_ai21_client__when_pass_ai21_api_host__should_add_suffix():
ai21_url = "https://api.ai21.com"
client = AsyncAI21Client(api_host=ai21_url)
assert client._http_client._base_url == f"{ai21_url}/studio/v1"


@pytest.mark.asyncio
def test_async_ai21_client__when_pass_ai21_with_suffix__should_not_modify():
ai21_url = "https://api.ai21.com/studio/v1"
client = AsyncAI21Client(api_host=ai21_url)
assert client._http_client._base_url == ai21_url
24 changes: 24 additions & 0 deletions tests/unittests/clients/studio/test_async_ai21_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from ai21 import AI21Client, AI21EnvConfig


def test_ai21_client__when_pass_api_host__should_leave_as_is():
base_url = "https://dont-modify-me.com"
client = AI21Client(api_host=base_url)
assert client._http_client._base_url == base_url


def test_ai21_client__when_not_pass_api_host__should_add_suffix():
client = AI21Client()
assert client._http_client._base_url == f"{AI21EnvConfig.api_host}/studio/v1"


def test_ai21_client__when_pass_ai21_api_host__should_add_suffix():
ai21_url = "https://api.ai21.com"
client = AI21Client(api_host=ai21_url)
assert client._http_client._base_url == f"{ai21_url}/studio/v1"


def test_ai21_client__when_pass_ai21_with_suffix__should_not_modify():
ai21_url = "https://api.ai21.com/studio/v1"
client = AI21Client(api_host=ai21_url)
assert client._http_client._base_url == ai21_url