Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,4 @@ fabric.properties
# https://github.com/github/gitignore/blob/master/Global/Patch.gitignore
*.orig
*.rej
__pycache__
2 changes: 1 addition & 1 deletion ai21/ai21_env_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST


@dataclass(frozen=True)
@dataclass
class _AI21EnvConfig:
api_key: Optional[str] = None
api_version: str = DEFAULT_API_VERSION
Expand Down
3 changes: 2 additions & 1 deletion ai21/ai21_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(
self,
*,
api_key: Optional[str] = None,
requires_api_key: bool = True,
api_host: Optional[str] = None,
api_version: Optional[str] = None,
headers: Optional[Dict[str, Any]] = None,
Expand All @@ -21,7 +22,7 @@ def __init__(
):
self._api_key = api_key

if not self._api_key:
if requires_api_key and not self._api_key:
raise MissingApiKeyError()

self._api_host = api_host
Expand Down
1 change: 1 addition & 0 deletions ai21/services/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _create_ai21_http_client(cls) -> AI21HTTPClient:
return AI21HTTPClient(
api_key=AI21EnvConfig.api_key,
api_host=AI21EnvConfig.api_host,
requires_api_key=False,
api_version=AI21EnvConfig.api_version,
timeout_sec=AI21EnvConfig.timeout_sec,
num_retries=AI21EnvConfig.num_retries,
Expand Down
28 changes: 22 additions & 6 deletions tests/integration_tests/services/test_sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@
import pytest

from ai21 import SageMaker
from tests.integration_tests.skip_helpers import should_skip_studio_integration_tests
from ai21 import SageMaker, AI21EnvConfig


@pytest.mark.skipif(should_skip_studio_integration_tests(), reason="No key supplied for AI21 Studio. Skipping.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a test where we set an env variable of AI21_API_KEY to validates that it still works?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

def test_sagemker__get_model_package_arn():
def _add_or_remove_api_key(use_api_key: bool):
if use_api_key:
AI21EnvConfig.api_key = "test"
else:
AI21EnvConfig.api_key = None


@pytest.mark.parametrize(
argnames="use_api_key",
argvalues=[True, False],
ids=["with_api_key", "without_api_key"],
)
def test_sagemaker__get_model_package_arn(use_api_key: bool):
_add_or_remove_api_key(use_api_key)
model_packages_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1")
assert isinstance(model_packages_arn, str)
assert len(model_packages_arn) > 0


@pytest.mark.skipif(should_skip_studio_integration_tests(), reason="No key supplied for AI21 Studio. Skipping.")
def test_sagemker__list_model_package_versions():
@pytest.mark.parametrize(
argnames="use_api_key",
argvalues=[True, False],
ids=["with_api_key", "without_api_key"],
)
def test_sagemaker__list_model_package_versions(use_api_key: bool):
_add_or_remove_api_key(use_api_key)
model_packages_arn = SageMaker.list_model_package_versions(model_name="j2-mid", region="us-east-1")
assert isinstance(model_packages_arn, list)
assert len(model_packages_arn) > 0