From 2e87234e1d8b2c0a992c0a4e4d2b95727bd63209 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:16:19 +0200 Subject: [PATCH 01/45] ci: Change main proj name (#22) * ci: precommit * fix: version name * ci: Added rc pipeline for semantic release * ci: pypi api keys * ci: Added rc pipeline for semantic release * ci: pypi api keys --- .github/workflows/publish.yaml | 3 +-- pyproject.toml | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 3250a378..e27ee41b 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -37,5 +37,4 @@ jobs: uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 with: user: __token__ - password: ${{ secrets.TEST_PYPI_TOKEN }} - repository_url: https://test.pypi.org/legacy/ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/pyproject.toml b/pyproject.toml index de30502a..9ec202da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,8 +106,8 @@ commit_message = "chore(release): v{version} [skip ci]" [tool.semantic_release.branches.main] match = "(main)" + +[tool.semantic_release.branches."Release Candidates"] +match = "(rc_*)" prerelease_token = "rc" prerelease = true - -[tool.semantic_release.changelog.environment] -newline_sequence = "\n" From ef30164b09497ff8a7134e8868e1df9980104613 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 19 Dec 2023 11:23:26 +0200 Subject: [PATCH 02/45] fix: readme --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index cdcdbfb6..74e5ed23 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ # AI21 Labs Python SDK -This repository includes the SDK for the AI21 Labs API. The SDK is a Python package that can be installed using pip. +This repository includes the SDK for the AI21 Labs API. The SDK is a Python +package that can be installed using `pip install ai21`. From 2f53ec9abebd580a33c382cd1d544d234c74dbbf Mon Sep 17 00:00:00 2001 From: github-actions Date: Tue, 19 Dec 2023 09:24:34 +0000 Subject: [PATCH 03/45] chore(release): v2.0.0-rc.4 [skip ci] --- CHANGELOG.md | 27 +++++++++++++++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57959ac5..3d71390a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,35 @@ +## v2.0.0-rc.4 (2023-12-19) + +### Ci + +* ci: Change main proj name (#22) + +* ci: precommit + +* fix: version name + +* ci: Added rc pipeline for semantic release + +* ci: pypi api keys + +* ci: Added rc pipeline for semantic release + +* ci: pypi api keys ([`2e87234`](https://github.com/AI21Labs/ai21-python/commit/2e87234e1d8b2c0a992c0a4e4d2b95727bd63209)) + +### Fix + +* fix: readme ([`ef30164`](https://github.com/AI21Labs/ai21-python/commit/ef30164b09497ff8a7134e8868e1df9980104613)) + + ## v2.0.0-rc.3 (2023-12-18) +### Chore + +* chore(release): v2.0.0-rc.3 [skip ci] ([`0a9ace7`](https://github.com/AI21Labs/ai21-python/commit/0a9ace7dd8b59eb51b6dcb4e4a1118aaa012b454)) + ### Fix * fix: Change main project name in setup (#17) diff --git a/ai21/version.py b/ai21/version.py index ae5502ec..acd37e4a 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.3" +VERSION = "2.0.0-rc.4" diff --git a/pyproject.toml b/pyproject.toml index 9ec202da..ae241312 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.3" +version = "2.0.0-rc.4" description = "" authors = ["AI21 Labs"] readme = "README.md" From bbb87d351c6f71ead0616c8bb90b1715285861a6 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:21:11 +0200 Subject: [PATCH 04/45] docs: README.md (#23) * ci: updated precommit hooks * docs: more readme updates * docs: removed extra lines * fix: rename * docs: readme * docs: full readme * docs: badges * ci: commitizen version * revert: via --- .pre-commit-config.yaml | 157 ++++------- README.md | 297 +++++++++++++++++++- ai21/__init__.py | 3 + ai21/clients/bedrock/bedrock_session.py | 4 +- ai21/clients/sagemaker/sagemaker_session.py | 7 +- ai21/clients/studio/ai21_client.py | 2 +- ai21/errors.py | 20 +- ai21/http_client.py | 8 +- examples/studio/library.py | 13 +- examples/studio/tokenization.py | 2 +- pyproject.toml | 3 + 11 files changed, 380 insertions(+), 136 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 42b3a452..28d2342b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,87 +1,55 @@ -# yaml-language-server: $schema=https://json.schemastore.org/pre-commit-config.json - minimum_pre_commit_version: 2.20.0 fail_fast: false default_stages: - commit +exclude: (.idea|vscode) repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.4.0 hooks: - id: check-added-large-files + exclude: (ai21_tokenizer/resources|tests/resources) - id: check-case-conflict - id: check-executables-have-shebangs - id: check-shebang-scripts-are-executable - id: check-merge-conflict - id: check-symlinks - id: detect-private-key - exclude: .gitleaks.toml - id: no-commit-to-branch + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: mixed-line-ending exclude: (CHANGELOG.md) - - id: check-json - - id: check-toml - - id: check-xml - - id: pretty-format-json - args: - - --autofix - - --no-sort-keys - repo: https://github.com/jumanjihouse/pre-commit-hooks - rev: 2.1.5 + rev: 3.0.0 hooks: - id: forbid-binary + exclude: (ai21_tokenizer/resources|tests/resources) - id: git-check files: "CHANGELOG.md" - - repo: https://github.com/adrienverge/yamllint - rev: v1.26.3 + - repo: https://github.com/commitizen-tools/commitizen + rev: v3.5.3 hooks: - - id: yamllint - name: Lint YAML files - args: - - --strict - - repo: https://github.com/sirosen/check-jsonschema - rev: 0.4.1 + - id: commitizen + name: Lint commit message + stages: + - commit-msg + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.23.3 hooks: - id: check-jsonschema - name: Validate GitHub Workflows - files: ^\.github/workflows/.*\.yml - types: - - yaml - args: - - --schemafile - - https://json.schemastore.org/github-workflow.json - - id: check-jsonschema - name: Validate GitHub Actions - files: > - (?x)^( - .*/action\.(yml|yaml)| - \.github/actions/.* - )$ - types: - - yaml - args: - - --schemafile - - https://json.schemastore.org/github-action - - id: check-jsonschema - name: Validate DependaBot - files: ^\.github/dependabot\.yml - types: - - yaml - args: - - --schemafile - - https://json.schemastore.org/dependabot-2.0.json - - id: check-jsonschema - name: Validate MarkdownLint - files: .*\.markdownlint\.yaml + name: Validate Pre-commit + files: .*\.pre-commit-config\.yaml types: - yaml args: - --schemafile - - https://json.schemastore.org/markdownlint.json + - https://json.schemastore.org/pre-commit-config.json - id: check-jsonschema - name: Validate YamlLint + name: Validate YamlLint configuration files: .*\.yamllint\.yaml types: - yaml @@ -89,38 +57,37 @@ repos: - --schemafile - https://json.schemastore.org/yamllint.json - id: check-jsonschema - name: Validate Pre-commit - files: .*\.pre-commit-config\.yaml + name: Validate Prettier configuration + files: .*\.prettierrc\.yaml types: - yaml args: - --schemafile - - https://json.schemastore.org/pre-commit-config.json + - http://json.schemastore.org/prettierrc - id: check-jsonschema - name: Validate Docker-Compose - files: .*docker-compose\.yml + name: Validate ArgoWorkflow files + files: ^workflows/template/.* types: - yaml args: + - --verbose - --schemafile - - https://raw.githubusercontent.com/compose-spec/compose-spec/master/schema/compose-spec.json - - id: check-jsonschema - name: Validate Renovate - files: ^\.github/renovate\.json - types: - - json - args: - - --schemafile - - https://docs.renovatebot.com/renovate-schema.json - - repo: https://github.com/commitizen-tools/commitizen - rev: v2.18.0 + - https://raw.githubusercontent.com/argoproj/argo-workflows/master/api/jsonschema/schema.json + - repo: https://github.com/python-poetry/poetry + rev: 1.5.0 hooks: - - id: commitizen - name: Lint commit message - stages: - - commit-msg + - id: poetry-check + - repo: https://github.com/adrienverge/yamllint + rev: v1.32.0 + hooks: + - id: yamllint + name: Lint YAML files + args: + - --format + - parsable + - --strict - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.7.2.1 + rev: v0.9.0.5 hooks: - id: shellcheck name: Check sh files (and patch) @@ -140,45 +107,23 @@ repos: - yaml - markdown - shell - - repo: local + - repo: https://github.com/psf/black + rev: 23.7.0 hooks: - - id: list-files - name: List files - language: system - entry: bash -c 'echo $@' - stages: - - manual - - id: shfmt - name: Format sh files - language: docker_image - entry: mvdan/shfmt:v3.4.0 - args: - - -w - - -s - - -i - - "2" + - id: black types: - - shell - - id: markdownlint - name: Lint Markdown files - language: docker_image - entry: 06kellyjac/markdownlint-cli:0.28.1 + - python + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.280 + hooks: + - id: ruff args: - --fix - types: - - markdown + - repo: local + hooks: - id: hadolint name: Lint Dockerfiles language: docker_image - entry: hadolint/hadolint:v2.8.0 hadolint + entry: hadolint/hadolint:v2.10.0 hadolint types: - dockerfile - - id: gitleaks - name: Detect hardcoded secrets - language: docker_image - entry: zricethezav/gitleaks:v7.6.1 - args: - - --append-repo-config - - --config-path - - .gitleaks.toml - - --verbose diff --git a/README.md b/README.md index 74e5ed23..2bac70a8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,295 @@ -# AI21 Labs Python SDK +

+ AI21 Labs Python SDK +

-This repository includes the SDK for the AI21 Labs API. The SDK is a Python -package that can be installed using `pip install ai21`. +[//]: # "Add when public" +[//]: # 'Test' +[//]: # 'Supported Python versions' + +

+Package version +Poetry +Supported Python versions +License +

+ +--- + +## Migration from v1.3.3 and below + +In `v2.0.0` we introduced a new SDK that is not backwards compatible with the previous version. +This version allows for Non static instances of the client, defined parameters to each resource, modelized responses and +more. + +
+Migration Examples + +### Instance creation (not available in v1.3.3 and below) + +```python +from ai21 import AI21Client + +client = AI21Client(api_key='my_api_key') + +# or set api_key in environment variable - AI21_API_KEY and then +client = AI21Client() +``` + +### Completion before/after + +```diff +prompt = "some prompt" + +import ai21 + +- response = ai21.Completion.execute(model="j2-light", prompt=prompt, maxTokens=2) + + ++ client = ai21.AI21Client() ++ response = client.completion(model="j2-light", prompt=prompt, max_tokens=2) +``` + +### Tokenization and Token counting before/after + +```diff +- response = ai21.Tokenization.execute(text=prompt) +- print(len(response)) # number of tokens + ++ client = ai21.AI21Client() ++ token_count = client.count_tokens(text=prompt) +``` + +--- + +### AWS Client Creations + +### Bedrock Client creation before/after + +```diff +- import ai21 +- destination = ai21.BedrockDestination(model_id=ai21.BedrockModelID.J2_MID_V1) +- response = ai21.Completion.execute(prompt=prompt, maxTokens=1000, destination=destination) + ++ from ai21 import AI21BedrockClient, BedrockModelID ++ client = AI21BedrockClient() ++ response = client.completion.create(prompt=prompt, max_tokens=1000, model_id=BedrockModelID.J2_MID_V1) +``` + +### SageMaker Client creation before/after + +```diff +- import ai21 +- destination = ai21.SageMakerDestination("j2-mid-test-endpoint") +- response = ai21.Completion.execute(prompt=prompt, maxTokens=1000, destination=destination) + ++ from ai21 import AI21SageMakerClient ++ client = AI21SageMakerClient(endpoint_name="j2-mid-test-endpoint") ++ response = client.completion.create(prompt=prompt, max_tokens=1000) +``` + +
+ +## Installation + +### pip + +```bash +pip install ai21 +``` + +## Usage + +--- + +### Client Instance Creation + +```python +from ai21 import AI21Client + +client = AI21Client( + # defaults to os.enviorn.get('AI21_API_KEY') + api_key='my_api_key', +) + +response = client.completion.create( + prompt="", + max_tokens=10, + model="j2-mid", + temperature=0.3, + top_p=1, +) + +print(response.completions) +print(response.prompt) +``` + +### Token Counting + +--- + +By using the `count_tokens` method, you can estimate the billing for a given request. + +```python +from ai21 import AI21Client + +client = AI21Client() +client.count_tokens(text="some text") # returns int +``` + +### File Upload + +--- + +```python +from ai21 import AI21Client + +client = AI21Client() + +file_id = client.library.files.upload( + file_path="path/to/file", + path="path/to/file/in/library", + labels=["label1", "label2"], + public_url="www.example.com", +) + +uploaded_file = client.library.files.get(file_id) +``` + +## Environment Variables + +--- + +You can set several environment variables to configure the client. + +### Logging + +We use the standard library [`logging`](https://docs.python.org/3/library/logging.html) module. + +To enable logging, set the `AI21_LOG_LEVEL` environment variable. + +```bash +$ export AI21_LOG_LEVEL=debug +``` + +### Other Important Environment Variables + +- `AI21_API_KEY` - Your API key. If not set, you must pass it to the client constructor. +- `AI21_API_URL` - The base URL of the API. Defaults to `https://api.ai21.com/v1/`. +- `AI21_API_VERSION` - The API version. Defaults to `v1`. +- `AI21_API_HOST` - The API host. Defaults to `api.ai21.com`. +- `AI21_TIMEOUT_SEC` - The timeout for API requests. +- `AI21_NUM_RETRIES` - The maximum number of retries for API requests. Defaults to `3` retries. +- `AI21_AWS_REGION` - The AWS region to use for AWS clients. Defaults to `us-east-1`. + +## Error Handling + +--- + +```python +from ai21 import errors as ai21_errors +from ai21 import AI21Client, AI21APIError + +client = AI21Client() + +system = "You're a support engineer in a SaaS company" +messages = [ + { + "text": "Hello, I need help with a signup process.", + "role": "user", + "name": "Alice", + }, +] + +try: + chat_completion = client.chat.create( + messages=messages, + model="j2-ultra", + system=system + ) +except ai21_errors.AI21ServerError as e: + print("Server error and could not be reached") + print(e.details) +except ai21_errors.TooManyRequests as e: + print("A 429 status code was returned. Slow down on the requests") +except AI21APIError as e: + print("A non 200 status code error. For more error types see ai21.errors") + +``` + +## AWS Clients + +--- + +AI21 Library provides convenient ways to interact with two AWS clients for use with AWS SageMaker and AWS Bedrock. + +### Installation + +--- + +```bash +pip install "ai21[AWS]" +``` + +### Usage + +--- + +#### SageMaker + +```python +from ai21 import AI21SageMakerClient + +client = AI21SageMakerClient(endpoint_name="j2-endpoint-name") +response = client.summarize.create( + source="Text to summarize", + source_type="TEXT", +) +print(response.summary) +``` + +#### With Boto3 Session + +```python +from ai21 import AI21SageMakerClient +import boto3 +sm_session = boto3.Session(region_name="us-east-1") + +client = AI21SageMakerClient( + session=sm_session, + endpoint_name="j2-endpoint-name", +) +``` + +#### Bedrock + +--- + +```python +from ai21 import AI21BedrockClient, BedrockModelID + +client = AI21BedrockClient(region='us-east-2') # region is optional, as you can use the env variable instead +response = client.completion.create( + prompt="Your prompt here", + model_id=BedrockModelID.J2_MID_V1, + max_tokens=10, +) +print(response.completions[0].data.text) +``` + +#### With Boto3 Session + +```python +from ai21 import AI21BedrockClient, BedrockModelID +import boto3 +bedrock_session = boto3.Session(region_name="us-east-2") + +client = AI21BedrockClient( + session=bedrock_session, +) + +response = client.completion.create( + prompt="Your prompt here", + model_id=BedrockModelID.J2_MID_V1, + max_tokens=10, +) +``` diff --git a/ai21/__init__.py b/ai21/__init__.py index d5ea25a4..d614a1ca 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,6 +1,7 @@ from typing import Any from ai21.clients.studio.ai21_client import AI21Client +from ai21.errors import AI21APIError, AI21APITimeoutError from ai21.logger import setup_logger from ai21.resources.responses.answer_response import AnswerResponse from ai21.resources.responses.chat_response import ChatResponse @@ -58,6 +59,8 @@ def __getattr__(name: str) -> Any: __all__ = [ "AI21Client", + "AI21APIError", + "AI21APITimeoutError", "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID", diff --git a/ai21/clients/bedrock/bedrock_session.py b/ai21/clients/bedrock/bedrock_session.py index 82029da6..7d9f846c 100644 --- a/ai21/clients/bedrock/bedrock_session.py +++ b/ai21/clients/bedrock/bedrock_session.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError from ai21.logger import logger -from ai21.errors import AccessDenied, NotFound, APITimeoutError +from ai21.errors import AccessDenied, NotFound, AI21APITimeoutError from ai21.http_client import handle_non_success_response _ERROR_MSG_TEMPLATE = ( @@ -52,7 +52,7 @@ def _handle_client_error(self, client_exception: ClientError) -> None: raise NotFound(details=error_message) if status_code == 408: - raise APITimeoutError(details=error_message) + raise AI21APITimeoutError(details=error_message) if status_code == 424: error_message_template = re.compile(_ERROR_MSG_TEMPLATE) diff --git a/ai21/clients/sagemaker/sagemaker_session.py b/ai21/clients/sagemaker/sagemaker_session.py index c94c2576..031aefe7 100644 --- a/ai21/clients/sagemaker/sagemaker_session.py +++ b/ai21/clients/sagemaker/sagemaker_session.py @@ -4,7 +4,7 @@ import boto3 from botocore.exceptions import ClientError -from ai21.errors import BadRequest, ServiceUnavailable, ServerError, APIError +from ai21.errors import BadRequest, ServiceUnavailable, AI21ServerError, AI21APIError from ai21.http_client import handle_non_success_response from ai21.logger import logger @@ -25,7 +25,6 @@ def invoke_endpoint( self, input_json: str, ): - try: response = self._session.invoke_endpoint( EndpointName=self._endpoint_name, @@ -56,5 +55,5 @@ def _handle_client_error(self, client_exception: "ClientError"): if status_code == 429 or status_code == 503: raise ServiceUnavailable(details=error_message) if status_code == 500: - raise ServerError(details=error_message) - raise APIError(status_code, details=error_message) + raise AI21ServerError(details=error_message) + raise AI21APIError(status_code, details=error_message) diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 154da889..025a3bab 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -57,7 +57,7 @@ def __init__( self.library = StudioLibrary(studio_client) self.segmentation = StudioSegmentation(studio_client) - def count_token(self, text: str) -> int: + def count_tokens(self, text: str) -> int: # We might want to cache the tokenizer instance within the class # and not globally as it might be used by other instances diff --git a/ai21/errors.py b/ai21/errors.py index 2232b3ae..33cf336b 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -1,7 +1,7 @@ from typing import Optional -class APIError(Exception): +class AI21APIError(Exception): def __init__(self, status_code: int, details: Optional[str] = None): super().__init__(details) self.details = details @@ -11,47 +11,47 @@ def __str__(self) -> str: return f"Failed with http status code: {self.status_code} ({type(self).__name__}). Details: {self.details}" -class BadRequest(APIError): +class BadRequest(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(400, details) -class Unauthorized(APIError): +class Unauthorized(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(401, details) -class AccessDenied(APIError): +class AccessDenied(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(403, details) -class NotFound(APIError): +class NotFound(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(404, details) -class APITimeoutError(APIError): +class AI21APITimeoutError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(408, details) -class UnprocessableEntity(APIError): +class UnprocessableEntity(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(422, details) -class TooManyRequests(APIError): +class TooManyRequests(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(429, details) -class ServerError(APIError): +class AI21ServerError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(500, details) -class ServiceUnavailable(APIError): +class ServiceUnavailable(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(500, details) diff --git a/ai21/http_client.py b/ai21/http_client.py index 0a0e4c4b..7f9b0286 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -10,9 +10,9 @@ Unauthorized, UnprocessableEntity, TooManyRequests, - ServerError, + AI21ServerError, ServiceUnavailable, - APIError, + AI21APIError, ) DEFAULT_TIMEOUT_SEC = 300 @@ -33,10 +33,10 @@ def handle_non_success_response(status_code: int, response_text: str): if status_code == 429: raise TooManyRequests(details=response_text) if status_code == 500: - raise ServerError(details=response_text) + raise AI21ServerError(details=response_text) if status_code == 503: raise ServiceUnavailable(details=response_text) - raise APIError(status_code, details=response_text) + raise AI21APIError(status_code, details=response_text) def requests_retry_session(session, retries=0): diff --git a/examples/studio/library.py b/examples/studio/library.py index 5ace905b..e1377200 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -3,7 +3,7 @@ import file_utils from ai21 import AI21Client -from ai21.errors import APIError +from ai21.errors import AI21APIError # Use api_host for testing staging, default is production # os.environ["AI21_API_HOST"] = "https://api-stage.ai21.com" @@ -14,7 +14,7 @@ def validate_file_deleted(): try: client.library.files.get(file_id) - except APIError as e: + except AI21APIError as e: print(f"File not found. Exception: {e.details}") @@ -24,7 +24,6 @@ def validate_file_deleted(): path = os.path.join(file_path, file_name) file_utils.create_file(file_path, file_name, content="test content" * 100) - file_id = client.library.files.upload( file_path=path, path=file_path, @@ -41,7 +40,11 @@ def validate_file_deleted(): print(uploaded_file.labels) print(uploaded_file.public_url) -client.library.files.update(file_id, publicUrl="www.example-updated.com", labels=["label3", "label4"]) +client.library.files.update( + file_id, + publicUrl="www.example-updated.com", + labels=["label3", "label4"], +) updated_file = client.library.files.get(file_id) print(updated_file.name) print(updated_file.public_url) @@ -50,7 +53,7 @@ def validate_file_deleted(): client.library.files.delete(file_id) try: uploaded_file = client.library.files.get(file_id) -except APIError as e: +except AI21APIError as e: print(f"File not found. Exception: {e.details}") # Cleanup created file diff --git a/examples/studio/tokenization.py b/examples/studio/tokenization.py index 4b5e798d..21407185 100644 --- a/examples/studio/tokenization.py +++ b/examples/studio/tokenization.py @@ -33,5 +33,5 @@ ) client = AI21Client() # This is the new and recommended way to use the Tokenization module. The old "execute" method is deprecated. -response = client.count_token(prompt) +response = client.count_tokens(prompt) print(response) diff --git a/pyproject.toml b/pyproject.toml index ae241312..a20a5e06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,3 +111,6 @@ match = "(main)" match = "(rc_*)" prerelease_token = "rc" prerelease = true + +[tool.ruff] +line-length = 120 From 53c53cbd2f10556b54f41e313ec1cb4485657358 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:58:48 +0200 Subject: [PATCH 05/45] docs: Readme migration (#24) * docs: instance text * docs: client instance explanation --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 2bac70a8..cd5f74ad 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,9 @@ client = AI21Client(api_key='my_api_key') client = AI21Client() ``` +We No longer support static methods for each resource, instead we have a client instance that has a method for each +allowing for more flexibility and better control. + ### Completion before/after ```diff @@ -49,6 +52,8 @@ import ai21 + response = client.completion(model="j2-light", prompt=prompt, max_tokens=2) ``` +This applies to all resources. You would now need to create a client instance and use it to call the resource method. + ### Tokenization and Token counting before/after ```diff From d166253a10456725107e80af9c13f9b360a47b0f Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:13:50 +0200 Subject: [PATCH 06/45] ci: Remove python 3_7 support (#25) * ci: python 3.8 >= support * ci: unittests * test: dummy test --- .github/workflows/test.yaml | 4 +- poetry.lock | 88 +---------------------------------- pyproject.toml | 2 +- tests/unittests/__init__.py | 0 tests/unittests/test_dummy.py | 2 + 5 files changed, 7 insertions(+), 89 deletions(-) create mode 100644 tests/unittests/__init__.py create mode 100644 tests/unittests/test_dummy.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d1078370..3e454388 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -44,7 +44,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - name: Checkout uses: actions/checkout@v3 @@ -69,7 +69,7 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: | - poetry run pytest + poetry run pytest tests/unittests/ - name: Upload pytest test results uses: actions/upload-artifact@v3 with: diff --git a/poetry.lock b/poetry.lock index 9da530df..334d859e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -55,7 +55,6 @@ mypy-extensions = ">=0.4.3" pathspec = ">=0.9.0" platformdirs = ">=2" tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} -typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} [package.extras] @@ -228,7 +227,6 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} -importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} [[package]] name = "colorama" @@ -327,7 +325,6 @@ files = [ [package.dependencies] gitdb = ">=4.0.1,<5" -typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""} [package.extras] test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"] @@ -343,26 +340,6 @@ files = [ {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, ] -[[package]] -name = "importlib-metadata" -version = "6.7.0" -description = "Read metadata from Python packages" -optional = false -python-versions = ">=3.7" -files = [ - {file = "importlib_metadata-6.7.0-py3-none-any.whl", hash = "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"}, - {file = "importlib_metadata-6.7.0.tar.gz", hash = "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4"}, -] - -[package.dependencies] -typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} -zipp = ">=0.5" - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] - [[package]] name = "importlib-resources" version = "5.12.0" @@ -461,7 +438,6 @@ files = [ [package.dependencies] mdurl = ">=0.1,<1.0" -typing_extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} [package.extras] benchmarking = ["psutil", "pytest", "pytest-benchmark"] @@ -611,7 +587,6 @@ files = [ [package.dependencies] mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} typing-extensions = ">=4.1.0" [package.extras] @@ -667,9 +642,6 @@ files = [ {file = "platformdirs-4.0.0.tar.gz", hash = "sha256:cb633b2bcf10c51af60beb0ab06d2f1d69064b43abf4c185ca6b28865f3f9731"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.8\""} - [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] @@ -685,9 +657,6 @@ files = [ {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, ] -[package.dependencies] -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} - [package.extras] dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] @@ -705,7 +674,6 @@ files = [ [package.dependencies] annotated-types = ">=0.4.0" -importlib-metadata = {version = "*", markers = "python_version == \"3.7\""} pydantic-core = "2.14.5" typing-extensions = ">=4.6.1" @@ -872,7 +840,6 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" @@ -909,7 +876,6 @@ files = [ [package.dependencies] requests = ">=2.25.0" requests-toolbelt = ">=0.10.1" -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.8\""} [package.extras] autocompletion = ["argcomplete (>=1.10.0,<3)"] @@ -1266,56 +1232,6 @@ files = [ {file = "tomlkit-0.12.3.tar.gz", hash = "sha256:75baf5012d06501f07bee5bf8e801b9f343e7aac5a92581f20f80ce632e6b5a4"}, ] -[[package]] -name = "typed-ast" -version = "1.5.5" -description = "a fork of Python 2 and 3 ast modules with type comment support" -optional = false -python-versions = ">=3.6" -files = [ - {file = "typed_ast-1.5.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4bc1efe0ce3ffb74784e06460f01a223ac1f6ab31c6bc0376a21184bf5aabe3b"}, - {file = "typed_ast-1.5.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5f7a8c46a8b333f71abd61d7ab9255440d4a588f34a21f126bbfc95f6049e686"}, - {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:597fc66b4162f959ee6a96b978c0435bd63791e31e4f410622d19f1686d5e769"}, - {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d41b7a686ce653e06c2609075d397ebd5b969d821b9797d029fccd71fdec8e04"}, - {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5fe83a9a44c4ce67c796a1b466c270c1272e176603d5e06f6afbc101a572859d"}, - {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d5c0c112a74c0e5db2c75882a0adf3133adedcdbfd8cf7c9d6ed77365ab90a1d"}, - {file = "typed_ast-1.5.5-cp310-cp310-win_amd64.whl", hash = "sha256:e1a976ed4cc2d71bb073e1b2a250892a6e968ff02aa14c1f40eba4f365ffec02"}, - {file = "typed_ast-1.5.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c631da9710271cb67b08bd3f3813b7af7f4c69c319b75475436fcab8c3d21bee"}, - {file = "typed_ast-1.5.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b445c2abfecab89a932b20bd8261488d574591173d07827c1eda32c457358b18"}, - {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc95ffaaab2be3b25eb938779e43f513e0e538a84dd14a5d844b8f2932593d88"}, - {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61443214d9b4c660dcf4b5307f15c12cb30bdfe9588ce6158f4a005baeb167b2"}, - {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6eb936d107e4d474940469e8ec5b380c9b329b5f08b78282d46baeebd3692dc9"}, - {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e48bf27022897577d8479eaed64701ecaf0467182448bd95759883300ca818c8"}, - {file = "typed_ast-1.5.5-cp311-cp311-win_amd64.whl", hash = "sha256:83509f9324011c9a39faaef0922c6f720f9623afe3fe220b6d0b15638247206b"}, - {file = "typed_ast-1.5.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:44f214394fc1af23ca6d4e9e744804d890045d1643dd7e8229951e0ef39429b5"}, - {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:118c1ce46ce58fda78503eae14b7664163aa735b620b64b5b725453696f2a35c"}, - {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be4919b808efa61101456e87f2d4c75b228f4e52618621c77f1ddcaae15904fa"}, - {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:fc2b8c4e1bc5cd96c1a823a885e6b158f8451cf6f5530e1829390b4d27d0807f"}, - {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:16f7313e0a08c7de57f2998c85e2a69a642e97cb32f87eb65fbfe88381a5e44d"}, - {file = "typed_ast-1.5.5-cp36-cp36m-win_amd64.whl", hash = "sha256:2b946ef8c04f77230489f75b4b5a4a6f24c078be4aed241cfabe9cbf4156e7e5"}, - {file = "typed_ast-1.5.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2188bc33d85951ea4ddad55d2b35598b2709d122c11c75cffd529fbc9965508e"}, - {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0635900d16ae133cab3b26c607586131269f88266954eb04ec31535c9a12ef1e"}, - {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:57bfc3cf35a0f2fdf0a88a3044aafaec1d2f24d8ae8cd87c4f58d615fb5b6311"}, - {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:fe58ef6a764de7b4b36edfc8592641f56e69b7163bba9f9c8089838ee596bfb2"}, - {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d09d930c2d1d621f717bb217bf1fe2584616febb5138d9b3e8cdd26506c3f6d4"}, - {file = "typed_ast-1.5.5-cp37-cp37m-win_amd64.whl", hash = "sha256:d40c10326893ecab8a80a53039164a224984339b2c32a6baf55ecbd5b1df6431"}, - {file = "typed_ast-1.5.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fd946abf3c31fb50eee07451a6aedbfff912fcd13cf357363f5b4e834cc5e71a"}, - {file = "typed_ast-1.5.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ed4a1a42df8a3dfb6b40c3d2de109e935949f2f66b19703eafade03173f8f437"}, - {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:045f9930a1550d9352464e5149710d56a2aed23a2ffe78946478f7b5416f1ede"}, - {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381eed9c95484ceef5ced626355fdc0765ab51d8553fec08661dce654a935db4"}, - {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bfd39a41c0ef6f31684daff53befddae608f9daf6957140228a08e51f312d7e6"}, - {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8c524eb3024edcc04e288db9541fe1f438f82d281e591c548903d5b77ad1ddd4"}, - {file = "typed_ast-1.5.5-cp38-cp38-win_amd64.whl", hash = "sha256:7f58fabdde8dcbe764cef5e1a7fcb440f2463c1bbbec1cf2a86ca7bc1f95184b"}, - {file = "typed_ast-1.5.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:042eb665ff6bf020dd2243307d11ed626306b82812aba21836096d229fdc6a10"}, - {file = "typed_ast-1.5.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:622e4a006472b05cf6ef7f9f2636edc51bda670b7bbffa18d26b255269d3d814"}, - {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1efebbbf4604ad1283e963e8915daa240cb4bf5067053cf2f0baadc4d4fb51b8"}, - {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0aefdd66f1784c58f65b502b6cf8b121544680456d1cebbd300c2c813899274"}, - {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:48074261a842acf825af1968cd912f6f21357316080ebaca5f19abbb11690c8a"}, - {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:429ae404f69dc94b9361bb62291885894b7c6fb4640d561179548c849f8492ba"}, - {file = "typed_ast-1.5.5-cp39-cp39-win_amd64.whl", hash = "sha256:335f22ccb244da2b5c296e6f96b06ee9bed46526db0de38d2f0e5a6597b81155"}, - {file = "typed_ast-1.5.5.tar.gz", hash = "sha256:94282f7a354f36ef5dbce0ef3467ebf6a258e370ab33d5b40c249fa996e590dd"}, -] - [[package]] name = "typing-extensions" version = "4.7.1" @@ -1395,5 +1311,5 @@ aws = ["boto3"] [metadata] lock-version = "2.0" -python-versions = "^3.7" -content-hash = "c7f911ab7d00a33f519c8b1b13872837fc6a8572181f25ec7ff51fc3ad69977b" +python-versions = "^3.8" +content-hash = "71ce6369e72538e571ac954b5ebb4e66fa79c1752aa61af336144df577078cc4" diff --git a/pyproject.toml b/pyproject.toml index a20a5e06..120b38ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ packages = [ ] [tool.poetry.dependencies] -python = "^3.7" +python = "^3.8" requests = "^2.31.0" ai21-tokenizer = "^0.3.9" boto3 = { version = "^1.28.82", optional = true } diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/test_dummy.py b/tests/unittests/test_dummy.py new file mode 100644 index 00000000..39b433bc --- /dev/null +++ b/tests/unittests/test_dummy.py @@ -0,0 +1,2 @@ +def test_assert(): + assert True From e4bb903ef8775d9f755c6053e35176f407386342 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Wed, 20 Dec 2023 17:36:57 +0200 Subject: [PATCH 07/45] docs: Readme additions (#27) * fix: removed unnecessary url env var * fix: README.md CR --- README.md | 17 +++++++++-------- ai21/ai21_env_config.py | 2 -- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index cd5f74ad..03cbbae2 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ import ai21 + client = ai21.AI21Client() -+ response = client.completion(model="j2-light", prompt=prompt, max_tokens=2) ++ response = client.completion.create(model="j2-light", prompt=prompt, max_tokens=2) ``` This applies to all resources. You would now need to create a client instance and use it to call the resource method. @@ -179,9 +179,8 @@ $ export AI21_LOG_LEVEL=debug ### Other Important Environment Variables - `AI21_API_KEY` - Your API key. If not set, you must pass it to the client constructor. -- `AI21_API_URL` - The base URL of the API. Defaults to `https://api.ai21.com/v1/`. - `AI21_API_VERSION` - The API version. Defaults to `v1`. -- `AI21_API_HOST` - The API host. Defaults to `api.ai21.com`. +- `AI21_API_HOST` - The API host. Defaults to `https://api.ai21.com/v1/`. - `AI21_TIMEOUT_SEC` - The timeout for API requests. - `AI21_NUM_RETRIES` - The maximum number of retries for API requests. Defaults to `3` retries. - `AI21_AWS_REGION` - The AWS region to use for AWS clients. Defaults to `us-east-1`. @@ -235,6 +234,8 @@ AI21 Library provides convenient ways to interact with two AWS clients for use w pip install "ai21[AWS]" ``` +This will make sure you have the required dependencies installed, including `boto3 >= 1.28.82`. + ### Usage --- @@ -257,10 +258,10 @@ print(response.summary) ```python from ai21 import AI21SageMakerClient import boto3 -sm_session = boto3.Session(region_name="us-east-1") +boto_session = boto3.Session(region_name="us-east-1") client = AI21SageMakerClient( - session=sm_session, + session=boto_session, endpoint_name="j2-endpoint-name", ) ``` @@ -272,7 +273,7 @@ client = AI21SageMakerClient( ```python from ai21 import AI21BedrockClient, BedrockModelID -client = AI21BedrockClient(region='us-east-2') # region is optional, as you can use the env variable instead +client = AI21BedrockClient(region='us-east-1') # region is optional, as you can use the env variable instead response = client.completion.create( prompt="Your prompt here", model_id=BedrockModelID.J2_MID_V1, @@ -286,10 +287,10 @@ print(response.completions[0].data.text) ```python from ai21 import AI21BedrockClient, BedrockModelID import boto3 -bedrock_session = boto3.Session(region_name="us-east-2") +boto_session = boto3.Session(region_name="us-east-1") client = AI21BedrockClient( - session=bedrock_session, + session=boto_session, ) response = client.completion.create( diff --git a/ai21/ai21_env_config.py b/ai21/ai21_env_config.py index d2ea1246..9f3a46be 100644 --- a/ai21/ai21_env_config.py +++ b/ai21/ai21_env_config.py @@ -9,7 +9,6 @@ @dataclass(frozen=True) class _AI21EnvConfig: api_key: Optional[str] = None - api_url: Optional[str] = None api_version: str = DEFAULT_API_VERSION api_host: str = STUDIO_HOST organization: Optional[str] = None @@ -23,7 +22,6 @@ class _AI21EnvConfig: def from_env(cls) -> _AI21EnvConfig: return cls( api_key=os.getenv("AI21_API_KEY"), - api_url=os.getenv("AI21_API_URL"), api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION), api_host=os.getenv("AI21_API_HOST", STUDIO_HOST), organization=os.getenv("AI21_ORGANIZATION"), From c455b77ce20555c530f02107f1400287e928b371 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Mon, 25 Dec 2023 13:15:14 +0200 Subject: [PATCH 08/45] test: Unittests for 2.0.0 (#28) * test: get_tokenizer tests * fix: cases * test: Added some unittests to resources * fix: rename var * test: Added ai21 studio client tsts * fix: rename files * fix: Added types * test: added test to http * fix: removed unnecessary auth param * test: Added tests * test: Added sagemaker * test: Created a single session per instance * ci: removed unnecessary action --- .github/workflows/quality-checks.yml | 27 ---- ...1_studio_client.py => ai21_http_client.py} | 33 +++- ai21/clients/studio/ai21_client.py | 33 ++-- ai21/clients/studio/resources/studio_chat.py | 2 +- .../studio/resources/studio_completion.py | 2 +- .../studio/resources/studio_library.py | 4 +- ai21/http_client.py | 62 +++++--- ai21/resources/bases/chat_base.py | 2 +- ai21/resources/studio_resource.py | 7 +- ai21/services/sagemaker.py | 11 +- ai21/tokenizers/factory.py | 4 +- poetry.lock | 19 ++- pyproject.toml | 1 + tests/unittests/clients/__init__.py | 0 tests/unittests/clients/studio/__init__.py | 0 .../clients/studio/resources/__init__.py | 0 .../clients/studio/resources/conftest.py | 122 +++++++++++++++ .../studio/resources/test_studio_resources.py | 80 ++++++++++ tests/unittests/conftest.py | 12 ++ tests/unittests/services/__init__.py | 0 tests/unittests/services/sagemaker_stub.py | 12 ++ tests/unittests/services/test_sagemaker.py | 44 ++++++ tests/unittests/test_ai21_http_client.py | 142 ++++++++++++++++++ tests/unittests/test_dummy.py | 2 - tests/unittests/tokenizers/__init__.py | 0 .../tokenizers/test_ai21_tokenizer.py | 25 +++ 26 files changed, 558 insertions(+), 88 deletions(-) delete mode 100644 .github/workflows/quality-checks.yml rename ai21/{ai21_studio_client.py => ai21_http_client.py} (70%) create mode 100644 tests/unittests/clients/__init__.py create mode 100644 tests/unittests/clients/studio/__init__.py create mode 100644 tests/unittests/clients/studio/resources/__init__.py create mode 100644 tests/unittests/clients/studio/resources/conftest.py create mode 100644 tests/unittests/clients/studio/resources/test_studio_resources.py create mode 100644 tests/unittests/conftest.py create mode 100644 tests/unittests/services/__init__.py create mode 100644 tests/unittests/services/sagemaker_stub.py create mode 100644 tests/unittests/services/test_sagemaker.py create mode 100644 tests/unittests/test_ai21_http_client.py delete mode 100644 tests/unittests/test_dummy.py create mode 100644 tests/unittests/tokenizers/__init__.py create mode 100644 tests/unittests/tokenizers/test_ai21_tokenizer.py diff --git a/.github/workflows/quality-checks.yml b/.github/workflows/quality-checks.yml deleted file mode 100644 index ba62006a..00000000 --- a/.github/workflows/quality-checks.yml +++ /dev/null @@ -1,27 +0,0 @@ -# yaml-language-server: $schema=https://json.schemastore.org/github-workflow.json - -name: Quality Checks -concurrency: - group: Quality-Checks-${{ github.head_ref }} - cancel-in-progress: true -on: - pull_request: -jobs: - quality-checks: - runs-on: ubuntu-20.04 - timeout-minutes: 10 - steps: - - name: Checkout - uses: actions/checkout@v3.2.0 - with: - fetch-depth: 0 - - name: Pre-commit - uses: pre-commit/action@v3.0.0 - with: - extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} - # - name: CODEOWNERS validator - # uses: mszostok/codeowners-validator@v0.6.0 - # with: - # checks: files,duppatterns,syntax,owners - # experimental_checks: notowned - # github_access_token: ${{ secrets.GH_PAT_RO }} diff --git a/ai21/ai21_studio_client.py b/ai21/ai21_http_client.py similarity index 70% rename from ai21/ai21_studio_client.py rename to ai21/ai21_http_client.py index 0ceef023..9bd3cb82 100644 --- a/ai21/ai21_studio_client.py +++ b/ai21/ai21_http_client.py @@ -1,12 +1,14 @@ +import io from typing import Optional, Dict, Any + from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyException from ai21.http_client import HttpClient from ai21.version import VERSION -class AI21StudioClient: +class AI21HTTPClient: def __init__( self, *, @@ -17,7 +19,9 @@ def __init__( timeout_sec: Optional[int] = None, num_retries: Optional[int] = None, organization: Optional[str] = None, + application: Optional[str] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): self._env_config = env_config @@ -32,12 +36,11 @@ def __init__( self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries self._organization = organization or self._env_config.organization - self._application = self._env_config.application + self._application = application or self._env_config.application self._via = via headers = self._build_headers(passed_headers=headers) - - self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) + self._http_client = self._init_http_client(http_client=http_client, headers=headers) def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: headers = { @@ -53,6 +56,18 @@ def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, return headers + def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str, Any]) -> HttpClient: + if http_client is None: + return HttpClient( + timeout_sec=self._timeout_sec, + num_retries=self._num_retries, + headers=headers, + ) + + http_client.add_headers(headers) + + return http_client + def _build_user_agent(self) -> str: user_agent = f"ai21 studio SDK {VERSION}" @@ -67,8 +82,14 @@ def _build_user_agent(self) -> str: return user_agent - def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None): - return self.http_client.execute_http_request(method=method, url=url, params=params, files=files) + def execute_http_request( + self, + method: str, + url: str, + params: Optional[Dict] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, + ): + return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) def get_base_url(self) -> str: return f"{self._api_host}/studio/{self._api_version}" diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 025a3bab..5f6cee06 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -1,6 +1,6 @@ from typing import Optional, Any, Dict -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -14,6 +14,7 @@ from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation from ai21.clients.studio.resources.studio_summarize import StudioSummarize from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment +from ai21.http_client import HttpClient from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer from ai21.tokenizers.factory import get_tokenizer @@ -33,29 +34,31 @@ def __init__( timeout_sec: Optional[float] = None, num_retries: Optional[int] = None, via: Optional[str] = None, + http_client: Optional[HttpClient] = None, **kwargs, ): - studio_client = AI21StudioClient( + self._http_client = AI21HTTPClient( api_key=api_key, api_host=api_host, headers=headers, timeout_sec=timeout_sec, num_retries=num_retries, via=via, + http_client=http_client, ) - self.completion = StudioCompletion(studio_client) - self.chat = StudioChat(studio_client) - self.summarize = StudioSummarize(studio_client) - self.embed = StudioEmbed(studio_client) - self.gec = StudioGEC(studio_client) - self.improvements = StudioImprovements(studio_client) - self.paraphrase = StudioParaphrase(studio_client) - self.summarize_by_segment = StudioSummarizeBySegment(studio_client) - self.custom_model = StudioCustomModel(studio_client) - self.dataset = StudioDataset(studio_client) - self.answer = StudioAnswer(studio_client) - self.library = StudioLibrary(studio_client) - self.segmentation = StudioSegmentation(studio_client) + self.completion = StudioCompletion(self._http_client) + self.chat = StudioChat(self._http_client) + self.summarize = StudioSummarize(self._http_client) + self.embed = StudioEmbed(self._http_client) + self.gec = StudioGEC(self._http_client) + self.improvements = StudioImprovements(self._http_client) + self.paraphrase = StudioParaphrase(self._http_client) + self.summarize_by_segment = StudioSummarizeBySegment(self._http_client) + self.custom_model = StudioCustomModel(self._http_client) + self.dataset = StudioDataset(self._http_client) + self.answer = StudioAnswer(self._http_client) + self.library = StudioLibrary(self._http_client) + self.segmentation = StudioSegmentation(self._http_client) def count_tokens(self, text: str) -> int: # We might want to cache the tokenizer instance within the class diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index f1dab12b..8fe1bca4 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._module_name}" + url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 84e179b3..48f85fa7 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -15,7 +15,7 @@ def create( num_results: Optional[int] = 1, min_tokens: Optional[int] = 0, temperature: Optional[float] = 0.7, - top_p: Optional[int] = 1, + top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, experimental_mode: bool = False, diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index ae785a85..42daedbb 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,6 +1,6 @@ from typing import Optional, List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.resources.responses.file_response import FileResponse from ai21.resources.responses.library_answer_response import LibraryAnswerResponse from ai21.resources.responses.library_search_response import LibrarySearchResponse @@ -10,7 +10,7 @@ class StudioLibrary(StudioResource): _module_name = "library/files" - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): super().__init__(client) self.files = LibraryFiles(client) self.search = LibrarySearch(client) diff --git a/ai21/http_client.py b/ai21/http_client.py index 7f9b0286..00692e07 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,5 +1,6 @@ +import io import json -from typing import Optional, Dict +from typing import Optional, Dict, Any import requests from requests.adapters import HTTPAdapter, Retry, RetryError @@ -55,34 +56,35 @@ def requests_retry_session(session, retries=0): class HttpClient: - def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None): - self.timeout_sec = timeout_sec if timeout_sec is not None else DEFAULT_TIMEOUT_SEC - self.num_retries = num_retries if num_retries is not None else DEFAULT_NUM_RETRIES - self.headers = headers if headers is not None else {} - self.apply_retry_policy = self.num_retries > 0 + def __init__( + self, + session: Optional[requests.Session] = None, + timeout_sec: int = None, + num_retries: int = None, + headers: Dict = None, + ): + self._timeout_sec = timeout_sec or DEFAULT_TIMEOUT_SEC + self._num_retries = num_retries or DEFAULT_NUM_RETRIES + self._headers = headers or {} + self._apply_retry_policy = self._num_retries > 0 + self._session = self._init_session(session) def execute_http_request( self, method: str, url: str, params: Optional[Dict] = None, - files=None, - auth=None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, ): - session = ( - requests_retry_session(requests.Session(), retries=self.num_retries) - if self.apply_retry_policy - else requests.Session() - ) - timeout = self.timeout_sec - headers = self.headers + timeout = self._timeout_sec + headers = self._headers data = json.dumps(params).encode() logger.info(f"Calling {method} {url} {headers} {data}") try: if method == "GET": - response = session.request( - method, - url, + response = self._session.request( + method=method, + url=url, headers=headers, timeout=timeout, params=params, @@ -96,23 +98,22 @@ def execute_http_request( headers.pop( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - response = session.request( - method, - url, + response = self._session.request( + method=method, + url=url, headers=headers, data=params, files=files, timeout=timeout, - auth=auth, ) else: - response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth) + response = self._session.request(method=method, url=url, headers=headers, data=data, timeout=timeout) except ConnectionError as connection_error: logger.error(f"Calling {method} {url} failed with ConnectionError: {connection_error}") raise connection_error except RetryError as retry_error: logger.error( - f"Calling {method} {url} failed with RetryError after {self.num_retries} attempts: {retry_error}" + f"Calling {method} {url} failed with RetryError after {self._num_retries} attempts: {retry_error}" ) raise retry_error except Exception as exception: @@ -124,3 +125,16 @@ def execute_http_request( handle_non_success_response(response.status_code, response.text) return response.json() + + def _init_session(self, session: Optional[requests.Session]) -> requests.Session: + if session is not None: + return session + + return ( + requests_retry_session(requests.Session(), retries=self._num_retries) + if self._apply_retry_policy + else requests.Session() + ) + + def add_headers(self, headers: Dict[str, Any]) -> None: + self._headers.update(headers) diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index f85270ee..e2a67c0d 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _module_name = "chat" + _MODULE_NAME = "chat" @abstractmethod def create( diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index a467bd94..7752be91 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,20 +1,21 @@ from __future__ import annotations +import io from abc import ABC from typing import Any, Dict, Optional -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient class StudioResource(ABC): - def __init__(self, client: AI21StudioClient): + def __init__(self, client: AI21HTTPClient): self._client = client def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, io.TextIOWrapper]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index fa811541..b1387622 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -1,6 +1,6 @@ from typing import List -from ai21.ai21_studio_client import AI21StudioClient +from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) @@ -18,7 +18,7 @@ class SageMaker: def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + client = cls._create_ai21_http_client() response = client.execute_http_request( method="POST", @@ -40,7 +40,8 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE @classmethod def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: _assert_model_package_exists(model_name=model_name, region=region) - client = AI21StudioClient() + + client = cls._create_ai21_http_client() response = client.execute_http_request( method="POST", @@ -53,6 +54,10 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: return response["versions"] + @classmethod + def _create_ai21_http_client(cls) -> AI21HTTPClient: + return AI21HTTPClient() + def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: diff --git a/ai21/tokenizers/factory.py b/ai21/tokenizers/factory.py index fd229231..cd728f77 100644 --- a/ai21/tokenizers/factory.py +++ b/ai21/tokenizers/factory.py @@ -16,6 +16,6 @@ def get_tokenizer() -> AI21Tokenizer: global _cached_tokenizer if _cached_tokenizer is None: - _cached_tokenizer = Tokenizer.get_tokenizer() + _cached_tokenizer = AI21Tokenizer(Tokenizer.get_tokenizer()) - return AI21Tokenizer(_cached_tokenizer) + return _cached_tokenizer diff --git a/poetry.lock b/poetry.lock index 334d859e..cd17ee01 100644 --- a/poetry.lock +++ b/poetry.lock @@ -848,6 +848,23 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-mock" +version = "3.12.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, + {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1312,4 +1329,4 @@ aws = ["boto3"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "71ce6369e72538e571ac954b5ebb4e66fa79c1752aa61af336144df577078cc4" +content-hash = "39ea6a4fd93efce593b30be52de954f1d6ab4c2d39745a9541067a5af5f37a21" diff --git a/pyproject.toml b/pyproject.toml index 120b38ae..060d165b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ safety = "*" ruff = "*" python-semantic-release = "^8.5.0" pytest = "^7.4.3" +pytest-mock = "^3.12.0" [tool.poetry.extras] AWS = ["boto3"] diff --git a/tests/unittests/clients/__init__.py b/tests/unittests/clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/__init__.py b/tests/unittests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/__init__.py b/tests/unittests/clients/studio/resources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py new file mode 100644 index 00000000..1a921fef --- /dev/null +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -0,0 +1,122 @@ +import pytest +from pytest_mock import MockerFixture + +from ai21 import AnswerResponse, ChatResponse, CompletionsResponse +from ai21.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.clients.studio.resources.studio_chat import StudioChat +from ai21.clients.studio.resources.studio_completion import StudioCompletion +from ai21.resources.responses.chat_response import ChatOutput, FinishReason +from ai21.resources.responses.completion_response import Prompt, Completion, CompletionData, CompletionFinishReason + + +@pytest.fixture +def mock_ai21_studio_client(mocker: MockerFixture) -> AI21HTTPClient: + return mocker.MagicMock(spec=AI21HTTPClient) + + +def get_studio_answer(): + _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" + _DUMMY_QUESTION = "What is the answer?" + + return ( + StudioAnswer, + {"context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION}, + "answer", + { + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + AnswerResponse(id="some-id", answer_in_context=True, answer="42"), + ) + + +def get_studio_chat(): + _DUMMY_MODEL = "dummy-chat-model" + _DUMMY_MESSAGES = [ + { + "text": "Hello, I need help with a signup process.", + "role": "user", + "name": "Alice", + }, + { + "text": "Hi Alice, I can help you with that. What seems to be the problem?", + "role": "assistant", + "name": "Bob", + }, + ] + _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" + + return ( + StudioChat, + {"model": _DUMMY_MODEL, "messages": _DUMMY_MESSAGES, "system": _DUMMY_SYSTEM}, + f"{_DUMMY_MODEL}/chat", + { + "model": _DUMMY_MODEL, + "system": _DUMMY_SYSTEM, + "messages": _DUMMY_MESSAGES, + "temperature": 0.7, + "maxTokens": 300, + "minTokens": 0, + "numResults": 1, + "topP": 1.0, + "topKReturn": 0, + "stopSequences": None, + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + ChatResponse( + outputs=[ + ChatOutput( + text="Hello, I need help with a signup process.", + role="user", + finish_reason=FinishReason(reason="dummy_reason", length=1, sequence="1"), + ) + ] + ), + ) + + +def get_studio_completion(): + _DUMMY_MODEL = "dummy-completion-model" + _DUMMY_PROMPT = "dummy-prompt" + + return ( + StudioCompletion, + {"model": _DUMMY_MODEL, "prompt": _DUMMY_PROMPT}, + f"{_DUMMY_MODEL}/complete", + { + "model": _DUMMY_MODEL, + "prompt": _DUMMY_PROMPT, + "temperature": 0.7, + "maxTokens": None, + "minTokens": 0, + "epoch": None, + "numResults": 1, + "topP": 1, + "customModel": None, + "experimentalModel": False, + "topKReturn": 0, + "stopSequences": [], + "frequencyPenalty": None, + "presencePenalty": None, + "countPenalty": None, + }, + CompletionsResponse( + id="some-id", + completions=[ + Completion( + data=CompletionData(text="dummy-completion", tokens=[]), + finish_reason=CompletionFinishReason(reason="dummy_reason", length=1), + ) + ], + prompt=Prompt(text="dummy-prompt"), + ), + ) + + +def get_studio_custom_model(): + pass diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py new file mode 100644 index 00000000..0e4de3af --- /dev/null +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -0,0 +1,80 @@ +from typing import TypeVar, Callable + +import pytest + +from ai21 import AnswerResponse +from ai21.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.resources.studio_answer import StudioAnswer +from ai21.resources.studio_resource import StudioResource +from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion + +_BASE_URL = "https://test.api.ai21.com/studio/v1" +_DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" +_DUMMY_QUESTION = "What is the answer?" + +T = TypeVar("T", bound=StudioResource) + + +class TestStudioResources: + @pytest.mark.parametrize( + ids=[ + "studio_answer", + "studio_chat", + "studio_completion", + ], + argnames=["studio_resource", "function_body", "url_suffix", "expected_body", "expected_response"], + argvalues=[ + (get_studio_answer()), + (get_studio_chat()), + (get_studio_completion()), + ], + ) + def test__create__should_return_answer_response( + self, + studio_resource: Callable[[AI21HTTPClient], T], + function_body, + url_suffix: str, + expected_body, + expected_response, + mock_ai21_studio_client: AI21HTTPClient, + ): + mock_ai21_studio_client.execute_http_request.return_value = expected_response.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + + resource = studio_resource(mock_ai21_studio_client) + + actual_response = resource.create( + **function_body, + ) + + assert actual_response == expected_response + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=f"{_BASE_URL}/{url_suffix}", + params=expected_body, + files=None, + ) + + def test__create__when_pass_kwargs__should_not_pass_to_request(self, mock_ai21_studio_client: AI21HTTPClient): + expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") + mock_ai21_studio_client.execute_http_request.return_value = expected_answer.to_dict() + mock_ai21_studio_client.get_base_url.return_value = _BASE_URL + studio_answer = StudioAnswer(mock_ai21_studio_client) + + studio_answer.create( + context=_DUMMY_CONTEXT, + question=_DUMMY_QUESTION, + some_dummy_kwargs="some_dummy_value", + ) + + mock_ai21_studio_client.execute_http_request.assert_called_with( + method="POST", + url=_BASE_URL + "/answer", + params={ + "answerLength": None, + "context": _DUMMY_CONTEXT, + "mode": None, + "question": _DUMMY_QUESTION, + }, + files=None, + ) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py new file mode 100644 index 00000000..02e4d467 --- /dev/null +++ b/tests/unittests/conftest.py @@ -0,0 +1,12 @@ +import pytest +import requests + + +@pytest.fixture +def dummy_api_host() -> str: + return "http://test_host" + + +@pytest.fixture +def mock_requests_session(mocker) -> requests.Session: + return mocker.Mock(spec=requests.Session) diff --git a/tests/unittests/services/__init__.py b/tests/unittests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/services/sagemaker_stub.py b/tests/unittests/services/sagemaker_stub.py new file mode 100644 index 00000000..16cd98c2 --- /dev/null +++ b/tests/unittests/services/sagemaker_stub.py @@ -0,0 +1,12 @@ +from unittest.mock import Mock + +from ai21 import SageMaker +from ai21.ai21_http_client import AI21HTTPClient + + +class SageMakerStub(SageMaker): + ai21_http_client = Mock(spec=AI21HTTPClient) + + @classmethod + def _create_ai21_http_client(cls): + return cls.ai21_http_client diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py new file mode 100644 index 00000000..a92c23fe --- /dev/null +++ b/tests/unittests/services/test_sagemaker.py @@ -0,0 +1,44 @@ +import pytest + +from ai21.errors import ModelPackageDoesntExistException +from tests.unittests.services.sagemaker_stub import SageMakerStub + +_DUMMY_ARN = "some-model-package-id1" +_DUMMY_VERSIONS = ["1.0.0", "1.0.1"] + + +class TestSageMakerService: + def test__get_model_package_arn__should_return_model_package_arn(self): + expected_response = { + "arn": _DUMMY_ARN, + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_ARN + + def test__get_model_package_arn__when_no_arn__should_raise_error(self): + SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} + + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") + + def test__list_model_package_versions__should_return_model_package_arn(self): + expected_response = { + "versions": _DUMMY_VERSIONS, + } + SageMakerStub.ai21_http_client.execute_http_request.return_value = expected_response + + actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") + + assert actual_model_package_arn == _DUMMY_VERSIONS + + def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") + + def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): + with pytest.raises(ModelPackageDoesntExistException): + SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py new file mode 100644 index 00000000..67c4f197 --- /dev/null +++ b/tests/unittests/test_ai21_http_client.py @@ -0,0 +1,142 @@ +from typing import Optional + +import pytest +import requests + +from ai21.ai21_http_client import AI21HTTPClient +from ai21.http_client import HttpClient +from ai21.version import VERSION + +_DUMMY_API_KEY = "dummy_key" +_EXPECTED_GET_HEADERS = { + "Authorization": "Bearer dummy_key", + "Content-Type": "application/json", + "User-Agent": f"ai21 studio SDK {VERSION}", +} + +_EXPECTED_POST_FILE_HEADERS = { + "Authorization": "Bearer dummy_key", + "User-Agent": f"ai21 studio SDK {VERSION}", +} + + +class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + +class TestAI21StudioClient: + @pytest.mark.parametrize( + ids=[ + "when_pass_only_via__should_include_via_in_user_agent", + "when_pass_only_application__should_include_application_in_user_agent", + "when_pass_organization__should_include_organization_in_user_agent", + "when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent", + ], + argnames=["via", "application", "organization", "expected_user_agent"], + argvalues=[ + ("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"), + (None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"), + (None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"), + ( + "langchain", + "studio", + "ai21", + f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain", + ), + ], + ) + def test__build_headers__user_agent( + self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str + ): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + assert client._http_client._headers["User-Agent"] == expected_user_agent + + def test__build_headers__authorization(self): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY) + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + def test__build_headers__when_pass_headers__should_append(self): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + assert client._http_client._headers["foo"] == "bar" + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + @pytest.mark.parametrize( + ids=[ + "when_api_host_is_not_set__should_return_default", + "when_api_host_is_set__should_return_set_value", + ], + argnames=["api_host", "expected_api_host"], + argvalues=[ + (None, "https://api.ai21.com/studio/v1"), + ("http://test_host", "http://test_host/studio/v1"), + ], + ) + def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + assert client.get_base_url() == expected_api_host + + @pytest.mark.parametrize( + ids=[ + "when_making_request__should_send_appropriate_parameters", + "when_making_request_with_files__should_send_appropriate_post_request", + ], + argnames=["params", "headers"], + argvalues=[ + ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), + ( + {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, + _EXPECTED_POST_FILE_HEADERS, + ), + ], + ) + def test__execute_http_request__( + self, + params, + headers, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + response_json = {"test_key": "test_value"} + mock_requests_session.request.return_value = MockResponse(response_json, 200) + + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) + + response = client.execute_http_request(**params) + assert response == response_json + + if "files" in params: + # We split it because when calling requests with "files", "params" is turned into "data" + mock_requests_session.request.assert_called_once_with( + timeout=300, + headers=headers, + files=params["files"], + data=params["params"], + url=params["url"], + method=params["method"], + ) + else: + mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + + def test__execute_http_request__when_files_with_put_method__should_raise_value_error( + self, + dummy_api_host: str, + mock_requests_session: requests.Session, + ): + response_json = {"test_key": "test_value"} + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient( + http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" + ) + + mock_requests_session.request.return_value = MockResponse(response_json, 200) + with pytest.raises(ValueError): + params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} + client.execute_http_request(**params) diff --git a/tests/unittests/test_dummy.py b/tests/unittests/test_dummy.py deleted file mode 100644 index 39b433bc..00000000 --- a/tests/unittests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_assert(): - assert True diff --git a/tests/unittests/tokenizers/__init__.py b/tests/unittests/tokenizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/tokenizers/test_ai21_tokenizer.py b/tests/unittests/tokenizers/test_ai21_tokenizer.py new file mode 100644 index 00000000..33f89d9e --- /dev/null +++ b/tests/unittests/tokenizers/test_ai21_tokenizer.py @@ -0,0 +1,25 @@ +from ai21.tokenizers.factory import get_tokenizer + + +class TestAI21Tokenizer: + def test__count_tokens__should_return_number_of_tokens(self): + expected_number_of_tokens = 8 + tokenizer = get_tokenizer() + + actual_number_of_tokens = tokenizer.count_tokens("Text to Tokenize - Hello world!") + + assert actual_number_of_tokens == expected_number_of_tokens + + def test__tokenize__should_return_list_of_tokens(self): + expected_tokens = ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"] + tokenizer = get_tokenizer() + + actual_tokens = tokenizer.tokenize("Text to Tokenize - Hello world!") + + assert actual_tokens == expected_tokens + + def test__tokenizer__should_be_singleton__when_called_twice(self): + tokenizer1 = get_tokenizer() + tokenizer2 = get_tokenizer() + + assert tokenizer1 is tokenizer2 From d6f73b5f3db3b234334b8e430d8a33633fb0247c Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Wed, 27 Dec 2023 14:08:35 +0200 Subject: [PATCH 09/45] fix: Feedback fixes (#29) * test: get_tokenizer tests * fix: cases * test: Added some unittests to resources * fix: rename var * test: Added ai21 studio client tsts * fix: rename files * fix: Added types * test: added test to http * fix: removed unnecessary auth param * test: Added tests * test: Added sagemaker * test: Created a single session per instance * ci: removed unnecessary action * fix: errors * fix: error renames * fix: rename upload * fix: rename type * fix: rename variable * fix: removed experimental * test: fixed * test: Added some unittests to resources * test: Added ai21 studio client tsts * fix: rename files * fix: Added types * test: added test to http * fix: removed unnecessary auth param * test: Added tests * test: Added sagemaker * test: Created a single session per instance * fix: errors * fix: error renames * fix: rename upload * fix: rename type * fix: rename variable * fix: removed experimental * test: fixed --- README.md | 4 +- ai21/__init__.py | 15 ++++- ai21/ai21_env_config.py | 4 -- ai21/ai21_http_client.py | 14 ++--- ai21/clients/bedrock/bedrock_session.py | 4 +- .../clients/studio/resources/studio_answer.py | 2 +- ai21/clients/studio/resources/studio_chat.py | 2 +- .../studio/resources/studio_completion.py | 5 -- .../studio/resources/studio_dataset.py | 2 +- .../studio/resources/studio_improvements.py | 4 +- .../studio/resources/studio_library.py | 2 +- ai21/errors.py | 59 ++++--------------- ai21/http_client.py | 11 ++-- ai21/resources/bases/answer_base.py | 2 +- ai21/resources/bases/chat_base.py | 2 +- ai21/resources/bases/completion_base.py | 3 - ai21/resources/bases/dataset_base.py | 2 +- ai21/resources/studio_resource.py | 5 +- ai21/services/sagemaker.py | 6 +- examples/studio/dataset.py | 2 +- examples/studio/library.py | 2 +- .../clients/studio/resources/conftest.py | 1 - tests/unittests/services/test_sagemaker.py | 8 +-- 23 files changed, 59 insertions(+), 102 deletions(-) diff --git a/README.md b/README.md index 03cbbae2..ac7edc0d 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,7 @@ from ai21 import AI21Client client = AI21Client() -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path="path/to/file", path="path/to/file/in/library", labels=["label1", "label2"], @@ -213,7 +213,7 @@ try: except ai21_errors.AI21ServerError as e: print("Server error and could not be reached") print(e.details) -except ai21_errors.TooManyRequests as e: +except ai21_errors.TooManyRequestsError as e: print("A 429 status code was returned. Slow down on the requests") except AI21APIError as e: print("A non 200 status code error. For more error types see ai21.errors") diff --git a/ai21/__init__.py b/ai21/__init__.py index d614a1ca..6c5fb3e9 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,7 +1,14 @@ from typing import Any from ai21.clients.studio.ai21_client import AI21Client -from ai21.errors import AI21APIError, AI21APITimeoutError +from ai21.errors import ( + AI21APIError, + APITimeoutError, + MissingApiKeyError, + ModelPackageDoesntExistError, + AI21Error, + TooManyRequestsError, +) from ai21.logger import setup_logger from ai21.resources.responses.answer_response import AnswerResponse from ai21.resources.responses.chat_response import ChatResponse @@ -60,7 +67,11 @@ def __getattr__(name: str) -> Any: __all__ = [ "AI21Client", "AI21APIError", - "AI21APITimeoutError", + "APITimeoutError", + "AI21Error", + "MissingApiKeyError", + "ModelPackageDoesntExistError", + "TooManyRequestsError", "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID", diff --git a/ai21/ai21_env_config.py b/ai21/ai21_env_config.py index 9f3a46be..01ef3501 100644 --- a/ai21/ai21_env_config.py +++ b/ai21/ai21_env_config.py @@ -11,8 +11,6 @@ class _AI21EnvConfig: api_key: Optional[str] = None api_version: str = DEFAULT_API_VERSION api_host: str = STUDIO_HOST - organization: Optional[str] = None - application: Optional[str] = None timeout_sec: Optional[int] = None num_retries: Optional[int] = None aws_region: Optional[str] = None @@ -24,8 +22,6 @@ def from_env(cls) -> _AI21EnvConfig: api_key=os.getenv("AI21_API_KEY"), api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION), api_host=os.getenv("AI21_API_HOST", STUDIO_HOST), - organization=os.getenv("AI21_ORGANIZATION"), - application=os.getenv("AI21_APPLICATION"), timeout_sec=os.getenv("AI21_TIMEOUT_SEC"), num_retries=os.getenv("AI21_NUM_RETRIES"), aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"), diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 9bd3cb82..68007654 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,9 +1,7 @@ -import io -from typing import Optional, Dict, Any - +from typing import Optional, Dict, Any, BinaryIO from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig -from ai21.errors import MissingApiKeyException +from ai21.errors import MissingApiKeyError from ai21.http_client import HttpClient from ai21.version import VERSION @@ -28,15 +26,15 @@ def __init__( self._api_key = api_key or self._env_config.api_key if self._api_key is None: - raise MissingApiKeyException() + raise MissingApiKeyError() self._api_host = api_host or self._env_config.api_host self._api_version = api_version or self._env_config.api_version self._headers = headers self._timeout_sec = timeout_sec or self._env_config.timeout_sec self._num_retries = num_retries or self._env_config.num_retries - self._organization = organization or self._env_config.organization - self._application = application or self._env_config.application + self._organization = organization + self._application = application self._via = via headers = self._build_headers(passed_headers=headers) @@ -87,7 +85,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): return self._http_client.execute_http_request(method=method, url=url, params=params, files=files) diff --git a/ai21/clients/bedrock/bedrock_session.py b/ai21/clients/bedrock/bedrock_session.py index 7d9f846c..82029da6 100644 --- a/ai21/clients/bedrock/bedrock_session.py +++ b/ai21/clients/bedrock/bedrock_session.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError from ai21.logger import logger -from ai21.errors import AccessDenied, NotFound, AI21APITimeoutError +from ai21.errors import AccessDenied, NotFound, APITimeoutError from ai21.http_client import handle_non_success_response _ERROR_MSG_TEMPLATE = ( @@ -52,7 +52,7 @@ def _handle_client_error(self, client_exception: ClientError) -> None: raise NotFound(details=error_message) if status_code == 408: - raise AI21APITimeoutError(details=error_message) + raise APITimeoutError(details=error_message) if status_code == 424: error_message_template = re.compile(_ERROR_MSG_TEMPLATE) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index ba79621e..5cd12fac 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -15,7 +15,7 @@ def create( mode: Optional[str] = None, **kwargs, ) -> AnswerResponse: - url = f"{self._client.get_base_url()}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 8fe1bca4..f1dab12b 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -39,6 +39,6 @@ def create( presence_penalty=presence_penalty, count_penalty=count_penalty, ) - url = f"{self._client.get_base_url()}/{model}/{self._MODULE_NAME}" + url = f"{self._client.get_base_url()}/{model}/{self._module_name}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 48f85fa7..10c1890f 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -18,7 +18,6 @@ def create( top_p: Optional[float] = 1, top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = None, frequency_penalty: Optional[Dict[str, Any]] = None, presence_penalty: Optional[Dict[str, Any]] = None, @@ -26,9 +25,6 @@ def create( epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: - if experimental_mode: - model = f"experimental/{model}" - url = f"{self._client.get_base_url()}/{model}" if custom_model is not None: @@ -45,7 +41,6 @@ def create( top_p=top_p, top_k_return=top_k_return, custom_model=custom_model, - experimental_mode=experimental_mode, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 05a07c52..8626d71b 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -6,7 +6,7 @@ class StudioDataset(StudioResource, Dataset): - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 50895e24..86287781 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -1,6 +1,6 @@ from typing import List -from ai21.errors import EmptyMandatoryListException +from ai21.errors import EmptyMandatoryListError from ai21.resources.bases.improvements_base import Improvements from ai21.resources.responses.improvement_response import ImprovementsResponse from ai21.resources.studio_resource import StudioResource @@ -9,7 +9,7 @@ class StudioImprovements(StudioResource, Improvements): def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: if len(types) == 0: - raise EmptyMandatoryListException("types") + raise EmptyMandatoryListError("types") url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(text=text, types=types) diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 42daedbb..b8f96a3c 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -20,7 +20,7 @@ def __init__(self, client: AI21HTTPClient): class LibraryFiles(StudioResource): _module_name = "library/files" - def upload( + def create( self, file_path: str, *, diff --git a/ai21/errors.py b/ai21/errors.py index 33cf336b..4a0f8c92 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -31,7 +31,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(404, details) -class AI21APITimeoutError(AI21APIError): +class APITimeoutError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(408, details) @@ -41,7 +41,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(422, details) -class TooManyRequests(AI21APIError): +class TooManyRequestsError(AI21APIError): def __init__(self, details: Optional[str] = None): super().__init__(429, details) @@ -56,7 +56,7 @@ def __init__(self, details: Optional[str] = None): super().__init__(500, details) -class AI21ClientException(Exception): +class AI21Error(Exception): def __init__(self, message: str): self.message = message super().__init__(message) @@ -65,57 +65,14 @@ def __str__(self) -> str: return f"{type(self).__name__} {self.message}" -class MissingInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is required for the {call_name} call" - super().__init__(message) - - -class UnsupportedInputException(AI21ClientException): - def __init__(self, field_name: str, call_name: str): - message = f"{field_name} is unsupported for the {call_name} call" - super().__init__(message) - - -class UnsupportedDestinationException(AI21ClientException): - def __init__(self, destination_name: str, call_name: str): - message = f'Destination of type {destination_name} is unsupported for the "{call_name}" call' - super().__init__(message) - - -class OnlyOneInputException(AI21ClientException): - def __init__(self, field_name1: str, field_name2: str, call_name: str): - message = f"{field_name1} or {field_name2} is required for the {call_name} call, but not both" - super().__init__(message) - - -class WrongInputTypeException(AI21ClientException): - def __init__(self, key: str, expected_type: type, given_type: type): - message = f"Supplied {key} should be {expected_type}, but {given_type} was passed instead" - super().__init__(message) - - -class EmptyMandatoryListException(AI21ClientException): - def __init__(self, key: str): - message = f"Supplied {key} is empty. At least one element should be present in the list" - super().__init__(message) - - -class MissingApiKeyException(AI21ClientException): +class MissingApiKeyError(AI21Error): def __init__(self): message = "API key must be supplied either globally in the ai21 namespace, or to be provided in the call args" super().__init__(message) self.message = message -class NoSpecifiedRegionException(AI21ClientException): - def __init__(self): - message = "No AWS region provided" - super().__init__(message) - self.message = message - - -class ModelPackageDoesntExistException(AI21ClientException): +class ModelPackageDoesntExistError(AI21Error): def __init__(self, model_name: str, region: str, version: Optional[str] = None): message = f"model_name: {model_name} doesn't exist in region: {region}" @@ -124,3 +81,9 @@ def __init__(self, model_name: str, region: str, version: Optional[str] = None): super().__init__(message) self.message = message + + +class EmptyMandatoryListError(AI21Error): + def __init__(self, key: str): + message = f"Supplied {key} is empty. At least one element should be present in the list" + super().__init__(message) diff --git a/ai21/http_client.py b/ai21/http_client.py index 00692e07..0eeac1a1 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -1,20 +1,19 @@ -import io import json -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, BinaryIO import requests from requests.adapters import HTTPAdapter, Retry, RetryError -from ai21.logger import logger from ai21.errors import ( BadRequest, Unauthorized, UnprocessableEntity, - TooManyRequests, + TooManyRequestsError, AI21ServerError, ServiceUnavailable, AI21APIError, ) +from ai21.logger import logger DEFAULT_TIMEOUT_SEC = 300 DEFAULT_NUM_RETRIES = 0 @@ -32,7 +31,7 @@ def handle_non_success_response(status_code: int, response_text: str): if status_code == 422: raise UnprocessableEntity(details=response_text) if status_code == 429: - raise TooManyRequests(details=response_text) + raise TooManyRequestsError(details=response_text) if status_code == 500: raise AI21ServerError(details=response_text) if status_code == 503: @@ -74,7 +73,7 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ): timeout = self._timeout_sec headers = self._headers diff --git a/ai21/resources/bases/answer_base.py b/ai21/resources/bases/answer_base.py index 0fbce8c0..4b11ff5c 100644 --- a/ai21/resources/bases/answer_base.py +++ b/ai21/resources/bases/answer_base.py @@ -5,7 +5,7 @@ class Answer(ABC): - _MODULE_NAME = "answer" + _module_name = "answer" def create( self, diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index e2a67c0d..f85270ee 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -11,7 +11,7 @@ class Message: class Chat(ABC): - _MODULE_NAME = "chat" + _module_name = "chat" @abstractmethod def create( diff --git a/ai21/resources/bases/completion_base.py b/ai21/resources/bases/completion_base.py index cb286df2..f549306a 100644 --- a/ai21/resources/bases/completion_base.py +++ b/ai21/resources/bases/completion_base.py @@ -20,7 +20,6 @@ def create( top_p=1, top_k_return=0, custom_model: Optional[str] = None, - experimental_mode: bool = False, stop_sequences: Optional[List[str]] = (), frequency_penalty: Optional[Dict[str, Any]] = {}, presence_penalty: Optional[Dict[str, Any]] = {}, @@ -44,7 +43,6 @@ def _create_body( top_p: Optional[int], top_k_return: Optional[int], custom_model: Optional[str], - experimental_mode: bool, stop_sequences: Optional[List[str]], frequency_penalty: Optional[Dict[str, Any]], presence_penalty: Optional[Dict[str, Any]], @@ -54,7 +52,6 @@ def _create_body( return { "model": model, "customModel": custom_model, - "experimentalModel": experimental_mode, "prompt": prompt, "maxTokens": max_tokens, "numResults": num_results, diff --git a/ai21/resources/bases/dataset_base.py b/ai21/resources/bases/dataset_base.py index dd53417c..2be49fc7 100644 --- a/ai21/resources/bases/dataset_base.py +++ b/ai21/resources/bases/dataset_base.py @@ -8,7 +8,7 @@ class Dataset(ABC): _module_name = "dataset" @abstractmethod - def upload( + def create( self, file_path: str, dataset_name: str, diff --git a/ai21/resources/studio_resource.py b/ai21/resources/studio_resource.py index 7752be91..8ece396e 100644 --- a/ai21/resources/studio_resource.py +++ b/ai21/resources/studio_resource.py @@ -1,8 +1,7 @@ from __future__ import annotations -import io from abc import ABC -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, BinaryIO from ai21.ai21_http_client import AI21HTTPClient @@ -15,7 +14,7 @@ def _post( self, url: str, body: Dict[str, Any], - files: Optional[Dict[str, io.TextIOWrapper]] = None, + files: Optional[Dict[str, BinaryIO]] = None, ) -> Dict[str, Any]: return self._client.execute_http_request( method="POST", diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index b1387622..f51e1ae2 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -4,7 +4,7 @@ from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, ) -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError _JUMPSTART_ENDPOINT = "jumpstart" _LIST_VERSIONS_ENDPOINT = f"{_JUMPSTART_ENDPOINT}/list_versions" @@ -33,7 +33,7 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE arn = response["arn"] if not arn: - raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version) + raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version) return arn @@ -61,4 +61,4 @@ def _create_ai21_http_client(cls) -> AI21HTTPClient: def _assert_model_package_exists(model_name, region): if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: - raise ModelPackageDoesntExistException(model_name=model_name, region=region) + raise ModelPackageDoesntExistError(model_name=model_name, region=region) diff --git a/examples/studio/dataset.py b/examples/studio/dataset.py index b07d6565..87e587cc 100644 --- a/examples/studio/dataset.py +++ b/examples/studio/dataset.py @@ -3,7 +3,7 @@ file_path = "" client = AI21Client() -client.dataset.upload(file_path=file_path, dataset_name="my_new_ds_name") +client.dataset.create(file_path=file_path, dataset_name="my_new_ds_name") result = client.dataset.list() print(result) first_ds_id = result[0].id diff --git a/examples/studio/library.py b/examples/studio/library.py index e1377200..d693d697 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -24,7 +24,7 @@ def validate_file_deleted(): path = os.path.join(file_path, file_name) file_utils.create_file(file_path, file_name, content="test content" * 100) -file_id = client.library.files.upload( +file_id = client.library.files.create( file_path=path, path=file_path, labels=["label1", "label2"], diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 1a921fef..6d94f2a7 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -98,7 +98,6 @@ def get_studio_completion(): "numResults": 1, "topP": 1, "customModel": None, - "experimentalModel": False, "topKReturn": 0, "stopSequences": [], "frequencyPenalty": None, diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index a92c23fe..dd36e1c9 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,6 +1,6 @@ import pytest -from ai21.errors import ModelPackageDoesntExistException +from ai21.errors import ModelPackageDoesntExistError from tests.unittests.services.sagemaker_stub import SageMakerStub _DUMMY_ARN = "some-model-package-id1" @@ -22,7 +22,7 @@ def test__get_model_package_arn__should_return_model_package_arn(self): def test__get_model_package_arn__when_no_arn__should_raise_error(self): SageMakerStub.ai21_http_client.execute_http_request.return_value = {"arn": []} - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") def test__list_model_package_versions__should_return_model_package_arn(self): @@ -36,9 +36,9 @@ def test__list_model_package_versions__should_return_model_package_arn(self): assert actual_model_package_arn == _DUMMY_VERSIONS def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistException): + with pytest.raises(ModelPackageDoesntExistError): SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") From 916c7b40395eb2678cf4e66d49206fc94bbe9b73 Mon Sep 17 00:00:00 2001 From: github-actions Date: Wed, 27 Dec 2023 12:09:19 +0000 Subject: [PATCH 10/45] chore(release): v2.0.0-rc.5 [skip ci] --- CHANGELOG.md | 157 ++++++++++++++++++++++++++++++++++++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 159 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d71390a..39446da2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,165 @@ +## v2.0.0-rc.5 (2023-12-27) + +### Ci + +* ci: Remove python 3_7 support (#25) + +* ci: python 3.8 >= support + +* ci: unittests + +* test: dummy test ([`d166253`](https://github.com/AI21Labs/ai21-python/commit/d166253a10456725107e80af9c13f9b360a47b0f)) + +### Documentation + +* docs: Readme additions (#27) + +* fix: removed unnecessary url env var + +* fix: README.md CR ([`e4bb903`](https://github.com/AI21Labs/ai21-python/commit/e4bb903ef8775d9f755c6053e35176f407386342)) + +* docs: Readme migration (#24) + +* docs: instance text + +* docs: client instance explanation ([`53c53cb`](https://github.com/AI21Labs/ai21-python/commit/53c53cbd2f10556b54f41e313ec1cb4485657358)) + +* docs: README.md (#23) + +* ci: updated precommit hooks + +* docs: more readme updates + +* docs: removed extra lines + +* fix: rename + +* docs: readme + +* docs: full readme + +* docs: badges + +* ci: commitizen version + +* revert: via ([`bbb87d3`](https://github.com/AI21Labs/ai21-python/commit/bbb87d351c6f71ead0616c8bb90b1715285861a6)) + +### Fix + +* fix: Feedback fixes (#29) + +* test: get_tokenizer tests + +* fix: cases + +* test: Added some unittests to resources + +* fix: rename var + +* test: Added ai21 studio client tsts + +* fix: rename files + +* fix: Added types + +* test: added test to http + +* fix: removed unnecessary auth param + +* test: Added tests + +* test: Added sagemaker + +* test: Created a single session per instance + +* ci: removed unnecessary action + +* fix: errors + +* fix: error renames + +* fix: rename upload + +* fix: rename type + +* fix: rename variable + +* fix: removed experimental + +* test: fixed + +* test: Added some unittests to resources + +* test: Added ai21 studio client tsts + +* fix: rename files + +* fix: Added types + +* test: added test to http + +* fix: removed unnecessary auth param + +* test: Added tests + +* test: Added sagemaker + +* test: Created a single session per instance + +* fix: errors + +* fix: error renames + +* fix: rename upload + +* fix: rename type + +* fix: rename variable + +* fix: removed experimental + +* test: fixed ([`d6f73b5`](https://github.com/AI21Labs/ai21-python/commit/d6f73b5f3db3b234334b8e430d8a33633fb0247c)) + +### Test + +* test: Unittests for 2.0.0 (#28) + +* test: get_tokenizer tests + +* fix: cases + +* test: Added some unittests to resources + +* fix: rename var + +* test: Added ai21 studio client tsts + +* fix: rename files + +* fix: Added types + +* test: added test to http + +* fix: removed unnecessary auth param + +* test: Added tests + +* test: Added sagemaker + +* test: Created a single session per instance + +* ci: removed unnecessary action ([`c455b77`](https://github.com/AI21Labs/ai21-python/commit/c455b77ce20555c530f02107f1400287e928b371)) + + ## v2.0.0-rc.4 (2023-12-19) +### Chore + +* chore(release): v2.0.0-rc.4 [skip ci] ([`2f53ec9`](https://github.com/AI21Labs/ai21-python/commit/2f53ec9abebd580a33c382cd1d544d234c74dbbf)) + ### Ci * ci: Change main proj name (#22) diff --git a/ai21/version.py b/ai21/version.py index acd37e4a..ba278f7c 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.4" +VERSION = "2.0.0-rc.5" diff --git a/pyproject.toml b/pyproject.toml index 060d165b..9fec5885 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.4" +version = "2.0.0-rc.5" description = "" authors = ["AI21 Labs"] readme = "README.md" From f84f86ab4992701ba8ff22b262317e1b336ea785 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Sun, 31 Dec 2023 10:32:14 +0200 Subject: [PATCH 11/45] refactor: Add enums (#30) * refactor: answer enum * refactor: answer - mode enum * refactor: moved imports * refactor: Added enums to chat requests/response * refactor: Added enums to completion requests/response * fix: imports * refactor: Added embed types enum * refactor: Added correction type enum * refactor: Added improvement type enum * refactor: Added enums to paraphrase and library answer * refactor: Added enums to segmentation * refactor: Added enums to summary * refactor: Added enums to summary by segment * fix: test --- .../bedrock/resources/bedrock_completion.py | 15 +++---- .../sagemaker/resources/sagemaker_answer.py | 6 +-- .../resources/sagemaker_completion.py | 15 +++---- .../resources/sagemaker_paraphrase.py | 3 +- .../resources/sagemaker_summarize.py | 5 ++- .../clients/studio/resources/studio_answer.py | 6 +-- .../studio/resources/studio_completion.py | 9 +++-- .../studio/resources/studio_improvements.py | 3 +- .../studio/resources/studio_library.py | 6 ++- .../studio/resources/studio_paraphrase.py | 3 +- .../studio/resources/studio_segmentation.py | 4 +- .../studio/resources/studio_summarize.py | 3 +- .../resources/studio_summarize_by_segment.py | 5 ++- ai21/resources/__init__.py | 23 +++++++++++ ai21/resources/bases/answer_base.py | 6 ++- ai21/resources/bases/chat_base.py | 29 ++++++++------ ai21/resources/bases/completion_base.py | 19 ++++----- ai21/resources/bases/embed_base.py | 8 +++- ai21/resources/bases/improvements_base.py | 3 +- ai21/resources/bases/paraphrase_base.py | 3 +- ai21/resources/bases/segmentation_base.py | 3 +- ai21/resources/bases/summarize_base.py | 3 +- .../bases/summarize_by_segment_base.py | 3 +- ai21/resources/models/__init__.py | 0 ai21/resources/models/answer_length.py | 7 ++++ ai21/resources/models/document_type.py | 6 +++ ai21/resources/models/improvement_type.py | 9 +++++ ai21/resources/models/mode.py | 6 +++ .../resources/models/paraphrase_style_type.py | 9 +++++ ai21/resources/models/penalty.py | 14 +++++++ ai21/resources/models/role_type.py | 6 +++ ai21/resources/models/summary_method.py | 7 ++++ ai21/resources/responses/chat_response.py | 3 +- ai21/resources/responses/gec_response.py | 12 +++++- examples/bedrock/completion.py | 38 +++++++++++++++++- examples/studio/answer.py | 5 ++- examples/studio/chat.py | 37 ++++++++++-------- examples/studio/completion.py | 39 ++++++++++++++++++- examples/studio/custom_model.py | 2 + examples/studio/embed.py | 3 +- examples/studio/improvements.py | 3 +- examples/studio/paraphrase.py | 9 ++++- examples/studio/segmentation.py | 10 +++-- examples/studio/summarize.py | 14 ++++--- examples/studio/summarize_by_segment.py | 5 ++- .../clients/studio/resources/conftest.py | 19 ++++----- 46 files changed, 334 insertions(+), 112 deletions(-) create mode 100644 ai21/resources/models/__init__.py create mode 100644 ai21/resources/models/answer_length.py create mode 100644 ai21/resources/models/document_type.py create mode 100644 ai21/resources/models/improvement_type.py create mode 100644 ai21/resources/models/mode.py create mode 100644 ai21/resources/models/paraphrase_style_type.py create mode 100644 ai21/resources/models/penalty.py create mode 100644 ai21/resources/models/role_type.py create mode 100644 ai21/resources/models/summary_method.py diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index 285ef0d7..150cf381 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,5 +1,6 @@ -from typing import Optional, List, Any, Dict +from typing import Optional, List +from ai21.resources import Penalty from ai21.resources.bedrock_resource import BedrockResource from ai21.resources.responses.completion_response import CompletionsResponse @@ -17,9 +18,9 @@ def create( top_p: Optional[int] = 1, top_k_return: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Dict[str, Any]] = None, - presence_penalty: Optional[Dict[str, Any]] = None, - count_penalty: Optional[Dict[str, Any]] = None, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, **kwargs, ) -> CompletionsResponse: body = { @@ -31,9 +32,9 @@ def create( "topP": top_p, "topKReturn": top_k_return, "stopSequences": stop_sequences or [], - "frequencyPenalty": frequency_penalty or {}, - "presencePenalty": presence_penalty or {}, - "countPenalty": count_penalty or {}, + "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), + "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), + "countPenalty": None if count_penalty is None else count_penalty.to_dict(), } raw_response = self._invoke(model_id=model_id, body=body) diff --git a/ai21/clients/sagemaker/resources/sagemaker_answer.py b/ai21/clients/sagemaker/resources/sagemaker_answer.py index 08112a51..1584abc9 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_answer.py +++ b/ai21/clients/sagemaker/resources/sagemaker_answer.py @@ -1,6 +1,6 @@ from typing import Optional -from ai21.resources.bases.answer_base import Answer +from ai21.resources.bases.answer_base import Answer, AnswerLength, Mode from ai21.resources.responses.answer_response import AnswerResponse from ai21.resources.sagemaker_resource import SageMakerResource @@ -11,8 +11,8 @@ def create( context: str, question: str, *, - answer_length: Optional[str] = None, - mode: Optional[str] = None, + answer_length: Optional[AnswerLength] = None, + mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 53dcb6e7..687f9b94 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,5 +1,6 @@ -from typing import Optional, Dict, Any, List +from typing import Optional, List +from ai21.resources import Penalty from ai21.resources.responses.completion_response import CompletionsResponse from ai21.resources.sagemaker_resource import SageMakerResource @@ -16,9 +17,9 @@ def create( top_p: Optional[int] = 1, top_k_return: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Dict[str, Any]] = None, - presence_penalty: Optional[Dict[str, Any]] = None, - count_penalty: Optional[Dict[str, Any]] = None, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, **kwargs, ) -> CompletionsResponse: body = { @@ -30,9 +31,9 @@ def create( "topP": top_p, "topKReturn": top_k_return, "stopSequences": stop_sequences or [], - "frequencyPenalty": frequency_penalty, - "presencePenalty": presence_penalty, - "countPenalty": count_penalty, + "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), + "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), + "countPenalty": None if count_penalty is None else count_penalty.to_dict(), } raw_response = self._invoke(body) diff --git a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py index 2ed20893..49d9ce5d 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py +++ b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py @@ -1,6 +1,7 @@ from typing import Optional from ai21.resources.bases.paraphrase_base import Paraphrase +from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType from ai21.resources.responses.paraphrase_response import ParaphraseResponse from ai21.resources.sagemaker_resource import SageMakerResource @@ -10,7 +11,7 @@ def create( self, text: str, *, - style: Optional[str] = None, + style: Optional[ParaphraseStyleType] = None, start_index: Optional[int] = 0, end_index: Optional[int] = None, **kwargs, diff --git a/ai21/clients/sagemaker/resources/sagemaker_summarize.py b/ai21/clients/sagemaker/resources/sagemaker_summarize.py index b137b693..c7db557b 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_summarize.py +++ b/ai21/clients/sagemaker/resources/sagemaker_summarize.py @@ -2,9 +2,10 @@ from typing import Optional +from ai21.resources.bases.summarize_base import Summarize +from ai21.resources.models.summary_method import SummaryMethod from ai21.resources.responses.summarize_response import SummarizeResponse from ai21.resources.sagemaker_resource import SageMakerResource -from ai21.resources.bases.summarize_base import Summarize class SageMakerSummarize(SageMakerResource, Summarize): @@ -14,7 +15,7 @@ def create( source_type: str, *, focus: Optional[str] = None, - summary_method: Optional[str] = None, + summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: body = self._create_body( diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index 5cd12fac..3962353c 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -1,6 +1,6 @@ from typing import Optional -from ai21.resources.bases.answer_base import Answer +from ai21.resources.bases.answer_base import Answer, AnswerLength, Mode from ai21.resources.responses.answer_response import AnswerResponse from ai21.resources.studio_resource import StudioResource @@ -11,8 +11,8 @@ def create( context: str, question: str, *, - answer_length: Optional[str] = None, - mode: Optional[str] = None, + answer_length: Optional[AnswerLength] = None, + mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: url = f"{self._client.get_base_url()}/{self._module_name}" diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 10c1890f..bcac84f8 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -1,5 +1,6 @@ -from typing import Optional, Dict, Any, List +from typing import Optional, List +from ai21.resources import Penalty from ai21.resources.bases.completion_base import Completion from ai21.resources.responses.completion_response import CompletionsResponse from ai21.resources.studio_resource import StudioResource @@ -19,9 +20,9 @@ def create( top_k_return: Optional[int] = 0, custom_model: Optional[str] = None, stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Dict[str, Any]] = None, - presence_penalty: Optional[Dict[str, Any]] = None, - count_penalty: Optional[Dict[str, Any]] = None, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 86287781..2e17cfd9 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -2,12 +2,13 @@ from ai21.errors import EmptyMandatoryListError from ai21.resources.bases.improvements_base import Improvements +from ai21.resources.models.improvement_type import ImprovementType from ai21.resources.responses.improvement_response import ImprovementsResponse from ai21.resources.studio_resource import StudioResource class StudioImprovements(StudioResource, Improvements): - def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: + def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: if len(types) == 0: raise EmptyMandatoryListError("types") diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index b8f96a3c..1962cd58 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,6 +1,8 @@ from typing import Optional, List from ai21.ai21_http_client import AI21HTTPClient +from ai21.resources.models.answer_length import AnswerLength +from ai21.resources.models.mode import Mode from ai21.resources.responses.file_response import FileResponse from ai21.resources.responses.library_answer_response import LibraryAnswerResponse from ai21.resources.responses.library_search_response import LibrarySearchResponse @@ -109,8 +111,8 @@ def create( path: Optional[str] = None, field_ids: Optional[List[str]] = None, labels: Optional[List[str]] = None, - answer_length: Optional[str] = None, - mode: Optional[str] = None, + answer_length: Optional[AnswerLength] = None, + mode: Optional[Mode] = None, **kwargs, ) -> LibraryAnswerResponse: url = f"{self._client.get_base_url()}/{self._module_name}" diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py index a3c602f2..1dbf06fb 100644 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ b/ai21/clients/studio/resources/studio_paraphrase.py @@ -1,6 +1,7 @@ from typing import Optional from ai21.resources.bases.paraphrase_base import Paraphrase +from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType from ai21.resources.responses.paraphrase_response import ParaphraseResponse from ai21.resources.studio_resource import StudioResource @@ -10,7 +11,7 @@ def create( self, text: str, *, - style: Optional[str] = None, + style: Optional[ParaphraseStyleType] = None, start_index: Optional[int] = None, end_index: Optional[int] = None, **kwargs, diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index 144789ba..0586561f 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -1,9 +1,11 @@ from ai21.resources.bases.segmentation_base import Segmentation +from ai21.resources.models.document_type import DocumentType +from ai21.resources.responses.segmentation_response import SegmentationResponse from ai21.resources.studio_resource import StudioResource class StudioSegmentation(StudioResource, Segmentation): - def create(self, source: str, source_type: str, **kwargs): + def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: body = self._create_body(source=source, source_type=source_type) url = f"{self._client.get_base_url()}/{self._module_name}" raw_response = self._post(url=url, body=body) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index 1c7dc675..1b9c44d3 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -1,6 +1,7 @@ from typing import Optional from ai21.resources.bases.summarize_base import Summarize +from ai21.resources.models.summary_method import SummaryMethod from ai21.resources.responses.summarize_response import SummarizeResponse from ai21.resources.studio_resource import StudioResource @@ -14,7 +15,7 @@ def create( source_type: str, *, focus: Optional[str] = None, - summary_method: Optional[str] = None, + summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: # Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object. diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py index 2ba253e9..7644a5f4 100644 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ b/ai21/clients/studio/resources/studio_summarize_by_segment.py @@ -1,15 +1,16 @@ from typing import Optional +from ai21.resources.bases.summarize_by_segment_base import SummarizeBySegment +from ai21.resources.models.document_type import DocumentType from ai21.resources.responses.summarize_by_segment_response import ( SummarizeBySegmentResponse, ) from ai21.resources.studio_resource import StudioResource -from ai21.resources.bases.summarize_by_segment_base import SummarizeBySegment class StudioSummarizeBySegment(StudioResource, SummarizeBySegment): def create( - self, source: str, source_type: str, *, focus: Optional[str] = None, **kwargs + self, source: str, source_type: DocumentType, *, focus: Optional[str] = None, **kwargs ) -> SummarizeBySegmentResponse: body = self._create_body( source=source, diff --git a/ai21/resources/__init__.py b/ai21/resources/__init__.py index e69de29b..626c7221 100644 --- a/ai21/resources/__init__.py +++ b/ai21/resources/__init__.py @@ -0,0 +1,23 @@ +from ai21.resources.bases.chat_base import Message +from ai21.resources.bases.embed_base import EmbedType +from ai21.resources.models.answer_length import AnswerLength +from ai21.resources.models.document_type import DocumentType +from ai21.resources.models.improvement_type import ImprovementType +from ai21.resources.models.mode import Mode +from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType +from ai21.resources.models.penalty import Penalty +from ai21.resources.models.role_type import RoleType +from ai21.resources.models.summary_method import SummaryMethod + +__all__ = [ + "AnswerLength", + "Mode", + "Message", + "RoleType", + "Penalty", + "EmbedType", + "ImprovementType", + "ParaphraseStyleType", + "DocumentType", + "SummaryMethod", +] diff --git a/ai21/resources/bases/answer_base.py b/ai21/resources/bases/answer_base.py index 4b11ff5c..96663084 100644 --- a/ai21/resources/bases/answer_base.py +++ b/ai21/resources/bases/answer_base.py @@ -1,6 +1,8 @@ from abc import ABC from typing import Optional, Any, Dict +from ai21.resources.models.answer_length import AnswerLength +from ai21.resources.models.mode import Mode from ai21.resources.responses.answer_response import AnswerResponse @@ -12,8 +14,8 @@ def create( context: str, question: str, *, - answer_length: Optional[str] = None, - mode: Optional[str] = None, + answer_length: Optional[AnswerLength] = None, + mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: pass diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index f85270ee..54854bb7 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -1,11 +1,16 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import List, Any, Dict, Optional +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.resources.models.penalty import Penalty +from ai21.resources.models.role_type import RoleType from ai21.resources.responses.chat_response import ChatResponse -class Message: - role: str +@dataclass +class Message(AI21BaseModelMixin): + role: RoleType text: str name: Optional[str] @@ -27,9 +32,9 @@ def create( top_p: Optional[float] = 1.0, top_k_returns: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Dict[str, Any]] = None, - presence_penalty: Optional[Dict[str, Any]] = None, - count_penalty: Optional[Dict[str, Any]] = None, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: pass @@ -49,14 +54,14 @@ def _create_body( top_p: Optional[float] = 1.0, top_k_returns: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Dict[str, Any]] = None, - presence_penalty: Optional[Dict[str, Any]] = None, - count_penalty: Optional[Dict[str, Any]] = None, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, ) -> Dict[str, Any]: return { "model": model, "system": system, - "messages": messages, + "messages": [message.to_dict() for message in messages], "temperature": temperature, "maxTokens": max_tokens, "minTokens": min_tokens, @@ -64,7 +69,7 @@ def _create_body( "topP": top_p, "topKReturn": top_k_returns, "stopSequences": stop_sequences, - "frequencyPenalty": frequency_penalty, - "presencePenalty": presence_penalty, - "countPenalty": count_penalty, + "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), + "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), + "countPenalty": None if count_penalty is None else count_penalty.to_dict(), } diff --git a/ai21/resources/bases/completion_base.py b/ai21/resources/bases/completion_base.py index f549306a..df9f2352 100644 --- a/ai21/resources/bases/completion_base.py +++ b/ai21/resources/bases/completion_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, List, Dict, Any +from ai21.resources import Penalty from ai21.resources.responses.completion_response import CompletionsResponse @@ -21,9 +22,9 @@ def create( top_k_return=0, custom_model: Optional[str] = None, stop_sequences: Optional[List[str]] = (), - frequency_penalty: Optional[Dict[str, Any]] = {}, - presence_penalty: Optional[Dict[str, Any]] = {}, - count_penalty: Optional[Dict[str, Any]] = {}, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: @@ -44,9 +45,9 @@ def _create_body( top_k_return: Optional[int], custom_model: Optional[str], stop_sequences: Optional[List[str]], - frequency_penalty: Optional[Dict[str, Any]], - presence_penalty: Optional[Dict[str, Any]], - count_penalty: Optional[Dict[str, Any]], + frequency_penalty: Optional[Penalty], + presence_penalty: Optional[Penalty], + count_penalty: Optional[Penalty], epoch: Optional[int], ): return { @@ -60,8 +61,8 @@ def _create_body( "topP": top_p, "topKReturn": top_k_return, "stopSequences": stop_sequences or [], - "frequencyPenalty": frequency_penalty, - "presencePenalty": presence_penalty, - "countPenalty": count_penalty, + "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), + "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), + "countPenalty": None if count_penalty is None else count_penalty.to_dict(), "epoch": epoch, } diff --git a/ai21/resources/bases/embed_base.py b/ai21/resources/bases/embed_base.py index 8ef93bb5..ee07187b 100644 --- a/ai21/resources/bases/embed_base.py +++ b/ai21/resources/bases/embed_base.py @@ -1,14 +1,20 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import List, Any, Dict, Optional from ai21.resources.responses.embed_response import EmbedResponse +class EmbedType(str, Enum): + QUERY = "query" + SEGMENT = "segment" + + class Embed(ABC): _module_name = "embed" @abstractmethod - def create(self, texts: List[str], *, type: Optional[str] = None, **kwargs) -> EmbedResponse: + def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: pass def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse: diff --git a/ai21/resources/bases/improvements_base.py b/ai21/resources/bases/improvements_base.py index 75b44ccf..69ac49e7 100644 --- a/ai21/resources/bases/improvements_base.py +++ b/ai21/resources/bases/improvements_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List +from ai21.resources.models.improvement_type import ImprovementType from ai21.resources.responses.improvement_response import ImprovementsResponse @@ -8,7 +9,7 @@ class Improvements(ABC): _module_name = "improvements" @abstractmethod - def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: + def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: pass def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse: diff --git a/ai21/resources/bases/paraphrase_base.py b/ai21/resources/bases/paraphrase_base.py index 34b76724..aa2e2e7d 100644 --- a/ai21/resources/bases/paraphrase_base.py +++ b/ai21/resources/bases/paraphrase_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict +from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType from ai21.resources.responses.paraphrase_response import ParaphraseResponse @@ -12,7 +13,7 @@ def create( self, text: str, *, - style: Optional[str] = None, + style: Optional[ParaphraseStyleType] = None, start_index: Optional[int] = 0, end_index: Optional[int] = None, **kwargs, diff --git a/ai21/resources/bases/segmentation_base.py b/ai21/resources/bases/segmentation_base.py index 40f1a409..f20cecb3 100644 --- a/ai21/resources/bases/segmentation_base.py +++ b/ai21/resources/bases/segmentation_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict +from ai21.resources.models.document_type import DocumentType from ai21.resources.responses.segmentation_response import SegmentationResponse @@ -8,7 +9,7 @@ class Segmentation(ABC): _module_name = "segmentation" @abstractmethod - def create(self, source: str, source_type: str): + def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: pass def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse: diff --git a/ai21/resources/bases/summarize_base.py b/ai21/resources/bases/summarize_base.py index 9190bded..435549a5 100644 --- a/ai21/resources/bases/summarize_base.py +++ b/ai21/resources/bases/summarize_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict +from ai21.resources.models.summary_method import SummaryMethod from ai21.resources.responses.summarize_response import SummarizeResponse @@ -12,7 +13,7 @@ def create( source_type: str, *, focus: Optional[str] = None, - summary_method: Optional[str] = None, + summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: pass diff --git a/ai21/resources/bases/summarize_by_segment_base.py b/ai21/resources/bases/summarize_by_segment_base.py index 4069664e..ac30d1ae 100644 --- a/ai21/resources/bases/summarize_by_segment_base.py +++ b/ai21/resources/bases/summarize_by_segment_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict +from ai21.resources.models.document_type import DocumentType from ai21.resources.responses.summarize_by_segment_response import ( SummarizeBySegmentResponse, ) @@ -13,7 +14,7 @@ class SummarizeBySegment(ABC): def create( self, source: str, - source_type: str, + source_type: DocumentType, *, focus: Optional[str] = None, **kwargs, diff --git a/ai21/resources/models/__init__.py b/ai21/resources/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/resources/models/answer_length.py b/ai21/resources/models/answer_length.py new file mode 100644 index 00000000..8adc82fd --- /dev/null +++ b/ai21/resources/models/answer_length.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class AnswerLength(str, Enum): + SHORT = "short" + MEDIUM = "medium" + LONG = "long" diff --git a/ai21/resources/models/document_type.py b/ai21/resources/models/document_type.py new file mode 100644 index 00000000..de4be01a --- /dev/null +++ b/ai21/resources/models/document_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class DocumentType(str, Enum): + URL = "URL" + TEXT = "TEXT" diff --git a/ai21/resources/models/improvement_type.py b/ai21/resources/models/improvement_type.py new file mode 100644 index 00000000..0774c4b0 --- /dev/null +++ b/ai21/resources/models/improvement_type.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class ImprovementType(str, Enum): + FLUENCY = "fluency" + VOCABULARY_SPECIFICITY = "vocabulary/specificity" + VOCABULARY_VARIETY = "vocabulary/variety" + CLARITY_SHORT_SENTENCES = "clarity/short-sentences" + CLARITY_CONCISENESS = "clarity/conciseness" diff --git a/ai21/resources/models/mode.py b/ai21/resources/models/mode.py new file mode 100644 index 00000000..e3e49347 --- /dev/null +++ b/ai21/resources/models/mode.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class Mode(str, Enum): + FLEXIBLE = "flexible" + STRICT = "strict" diff --git a/ai21/resources/models/paraphrase_style_type.py b/ai21/resources/models/paraphrase_style_type.py new file mode 100644 index 00000000..b7d7bd54 --- /dev/null +++ b/ai21/resources/models/paraphrase_style_type.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class ParaphraseStyleType(str, Enum): + LONG = "long" + SHORT = "short" + FORMAL = "formal" + CASUAL = "casual" + GENERAL = "general" diff --git a/ai21/resources/models/penalty.py b/ai21/resources/models/penalty.py new file mode 100644 index 00000000..74c69ff2 --- /dev/null +++ b/ai21/resources/models/penalty.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Optional + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class Penalty(AI21BaseModelMixin): + scale: float + apply_to_whitespaces: Optional[bool] = None + apply_to_punctuation: Optional[bool] = None + apply_to_numbers: Optional[bool] = None + apply_to_stopwords: Optional[bool] = None + apply_to_emojis: Optional[bool] = None diff --git a/ai21/resources/models/role_type.py b/ai21/resources/models/role_type.py new file mode 100644 index 00000000..a1630a23 --- /dev/null +++ b/ai21/resources/models/role_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class RoleType(str, Enum): + USER = "user" + ASSISTANT = "assistant" diff --git a/ai21/resources/models/summary_method.py b/ai21/resources/models/summary_method.py new file mode 100644 index 00000000..b4b05554 --- /dev/null +++ b/ai21/resources/models/summary_method.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class SummaryMethod(str, Enum): + SEGMENTS = "segments" + GUIDED = "guided" + FULL_DOCUMENT = "fullDocument" diff --git a/ai21/resources/responses/chat_response.py b/ai21/resources/responses/chat_response.py index f8398dfa..81ce3062 100644 --- a/ai21/resources/responses/chat_response.py +++ b/ai21/resources/responses/chat_response.py @@ -2,6 +2,7 @@ from typing import Optional, List from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.resources.models.role_type import RoleType @dataclass @@ -14,7 +15,7 @@ class FinishReason(AI21BaseModelMixin): @dataclass class ChatOutput(AI21BaseModelMixin): text: str - role: str + role: RoleType finish_reason: FinishReason diff --git a/ai21/resources/responses/gec_response.py b/ai21/resources/responses/gec_response.py index 37109deb..4d5cc0a2 100644 --- a/ai21/resources/responses/gec_response.py +++ b/ai21/resources/responses/gec_response.py @@ -1,16 +1,26 @@ from dataclasses import dataclass +from enum import Enum from typing import List from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +class CorrectionType(str, Enum): + GRAMMAR = "Grammar" + MISSING_WORD = "Missing Word" + PUNCTUATION = "Punctuation" + SPELLING = "Spelling" + WORD_REPETITION = "Word Repetition" + WRONG_WORD = "Wrong Word" + + @dataclass class Correction(AI21BaseModelMixin): suggestion: str start_index: int end_index: int original_text: str - correction_type: str + correction_type: CorrectionType @dataclass diff --git a/examples/bedrock/completion.py b/examples/bedrock/completion.py index 9b71a423..8002e258 100644 --- a/examples/bedrock/completion.py +++ b/examples/bedrock/completion.py @@ -1,4 +1,5 @@ from ai21 import AI21BedrockClient, BedrockModelID +from ai21.resources import Penalty # Bedrock is currently supported only in us-east-1 region. # Either set your profile's region to us-east-1 or uncomment next line @@ -38,7 +39,42 @@ "User: Hi, I have a question for you" ) -response = AI21BedrockClient().completion.create(prompt=prompt, max_tokens=1000, model_id=BedrockModelID.J2_MID_V1) +response = AI21BedrockClient().completion.create( + prompt=prompt, + max_tokens=1000, + model_id=BedrockModelID.J2_MID_V1, + temperature=0, + top_p=1, + top_k_return=0, + stop_sequences=["##"], + num_results=1, + custom_model=None, + epoch=1, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), +) print(response.completions[0].data.text) print(response.prompt.tokens[0]["textRange"]["start"]) diff --git a/examples/studio/answer.py b/examples/studio/answer.py index b726709e..4ba86649 100644 --- a/examples/studio/answer.py +++ b/examples/studio/answer.py @@ -1,4 +1,5 @@ from ai21 import AI21Client +from ai21.resources import Mode, AnswerLength client = AI21Client() @@ -9,7 +10,7 @@ "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " "economic power, dominating the other provinces of the newly independent Dutch Republic.", question="When did Holland become an economic power?", - answer_length="long", - mode="flexible", + answer_length=AnswerLength.LONG, + mode=Mode.FLEXIBLE, ) print(response) diff --git a/examples/studio/chat.py b/examples/studio/chat.py index 8a229e56..da8efa36 100644 --- a/examples/studio/chat.py +++ b/examples/studio/chat.py @@ -1,25 +1,30 @@ from ai21 import AI21Client +from ai21.resources import Message, RoleType, Penalty system = "You're a support engineer in a SaaS company" messages = [ - { - "text": "Hello, I need help with a signup process.", - "role": "user", - "name": "Alice", - }, - { - "text": "Hi Alice, I can help you with that. What seems to be the problem?", - "role": "assistant", - "name": "Bob", - }, - { - "text": "I am having trouble signing up for your product with my Google account.", - "role": "user", - "name": "Alice", - }, + Message(text="Hello, I need help with a signup process.", role=RoleType.USER, name="Alice"), + Message( + text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT, name="Bob" + ), + Message( + text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER, name="Alice" + ), ] client = AI21Client() -response = client.chat.create(system=system, messages=messages, model="j2-ultra") +response = client.chat.create( + system=system, + messages=messages, + model="j2-ultra", + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), +) print(response) diff --git a/examples/studio/completion.py b/examples/studio/completion.py index 5b8ffddb..333f31a7 100644 --- a/examples/studio/completion.py +++ b/examples/studio/completion.py @@ -1,5 +1,5 @@ from ai21 import AI21Client - +from ai21.resources import Penalty prompt = ( "The following is a conversation between a user of an eCommerce store and a user operation" @@ -33,7 +33,42 @@ ) client = AI21Client() -response = client.completion.create(prompt=prompt, max_tokens=2, model="j2-light", temperature=0) +response = client.completion.create( + prompt=prompt, + max_tokens=2, + model="j2-light", + temperature=0, + top_p=1, + top_k_return=0, + stop_sequences=["##"], + num_results=1, + custom_model=None, + epoch=1, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), +) print(response) print(response.completions[0].data.text) diff --git a/examples/studio/custom_model.py b/examples/studio/custom_model.py index 046881af..d1f7b7e1 100644 --- a/examples/studio/custom_model.py +++ b/examples/studio/custom_model.py @@ -1,3 +1,5 @@ +import uuid + from ai21 import AI21Client diff --git a/examples/studio/embed.py b/examples/studio/embed.py index 7829c7d9..57166d21 100644 --- a/examples/studio/embed.py +++ b/examples/studio/embed.py @@ -1,8 +1,9 @@ from ai21 import AI21Client +from ai21.resources import EmbedType client = AI21Client() response = client.embed.create( texts=["Holland is a geographical region[2] and former province on the western coast of the Netherlands."], - type="segment", + type=EmbedType.SEGMENT, ) print("embed: ", response.results[0].embedding) diff --git a/examples/studio/improvements.py b/examples/studio/improvements.py index 54f2b930..d1919591 100644 --- a/examples/studio/improvements.py +++ b/examples/studio/improvements.py @@ -1,10 +1,11 @@ from ai21 import AI21Client +from ai21.resources import ImprovementType client = AI21Client() response = client.improvements.create( text="Affiliated with the profession of project management," " I have ameliorated myself with a different set of hard skills as well as soft skills.", - types=["fluency"], + types=[ImprovementType.FLUENCY], ) print(response.improvements[0].original_text) diff --git a/examples/studio/paraphrase.py b/examples/studio/paraphrase.py index a80871fa..f7f643a6 100644 --- a/examples/studio/paraphrase.py +++ b/examples/studio/paraphrase.py @@ -1,8 +1,13 @@ from ai21 import AI21Client - +from ai21.resources import ParaphraseStyleType client = AI21Client() -response = client.paraphrase.create(text="The cat (Felis catus) is a domestic species of small carnivorous mammal") +response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.GENERAL, + start_index=0, + end_index=20, +) print(response.suggestions[0].text) print(response.suggestions[1].text) diff --git a/examples/studio/segmentation.py b/examples/studio/segmentation.py index 79658123..93d3a7be 100644 --- a/examples/studio/segmentation.py +++ b/examples/studio/segmentation.py @@ -1,11 +1,15 @@ from ai21 import AI21Client - +from ai21.resources import DocumentType client = AI21Client() response = client.segmentation.create( - source="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2] From the 10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and economic power, dominating the other provinces of the newly independent Dutch Republic.", - source_type="TEXT", + source="Holland is a geographical region[2] and former province on the western " + "coast of the Netherlands.[2] From the 10th to the 16th century, Holland proper was " + "a unified political region within the Holy Roman Empire as a county ruled by the counts of" + " Holland. By the 17th century, the province of Holland had risen to become a maritime and economic power," + " dominating the other provinces of the newly independent Dutch Republic.", + source_type=DocumentType.TEXT, ) print(response.segments[0].segment_text) diff --git a/examples/studio/summarize.py b/examples/studio/summarize.py index d07f1057..73dc9ece 100644 --- a/examples/studio/summarize.py +++ b/examples/studio/summarize.py @@ -1,11 +1,15 @@ from ai21 import AI21Client +from ai21.resources import DocumentType, SummaryMethod client = AI21Client() response = client.summarize.create( - source="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2] From the 10th to the 16th century, " - "Holland proper was a unified political region within the Holy Roman Empire as a county ruled by the counts of Holland. By the 17th century, " - "the province of Holland had risen to become a maritime and economic power, dominating the other provinces of the newly independent Dutch " - "Republic.", - source_type="TEXT", + source="Holland is a geographical region[2] and former province on the western " + "coast of the Netherlands.[2] From the 10th to the 16th century, Holland proper was " + "a unified political region within the Holy Roman Empire as a county ruled by the counts of" + " Holland. By the 17th century, the province of Holland had risen to become a maritime and economic power," + " dominating the other provinces of the newly independent Dutch Republic.", + source_type=DocumentType.TEXT, + summary_length=SummaryMethod.SEGMENTS, + focus="Holland", ) print(response.summary) diff --git a/examples/studio/summarize_by_segment.py b/examples/studio/summarize_by_segment.py index 239af4e7..4ab1d82f 100644 --- a/examples/studio/summarize_by_segment.py +++ b/examples/studio/summarize_by_segment.py @@ -1,5 +1,5 @@ from ai21 import AI21Client - +from ai21.resources import DocumentType client = AI21Client() response = client.summarize_by_segment.create( @@ -9,6 +9,7 @@ "county ruled by the counts of Holland. By the 17th century, " "the province of Holland had risen to become a maritime and economic power," " dominating the other provinces of the newly independent Dutch Republic.", - source_type="TEXT", + source_type=DocumentType.TEXT, + focus="Holland", ) print(response) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 6d94f2a7..8c50a082 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -6,6 +6,7 @@ from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion +from ai21.resources import Message, RoleType from ai21.resources.responses.chat_response import ChatOutput, FinishReason from ai21.resources.responses.completion_response import Prompt, Completion, CompletionData, CompletionFinishReason @@ -36,16 +37,12 @@ def get_studio_answer(): def get_studio_chat(): _DUMMY_MODEL = "dummy-chat-model" _DUMMY_MESSAGES = [ - { - "text": "Hello, I need help with a signup process.", - "role": "user", - "name": "Alice", - }, - { - "text": "Hi Alice, I can help you with that. What seems to be the problem?", - "role": "assistant", - "name": "Bob", - }, + Message(text="Hello, I need help with a signup process.", role=RoleType.USER, name="Alice"), + Message( + text="Hi Alice, I can help you with that. What seems to be the problem?", + role=RoleType.ASSISTANT, + name="Bob", + ), ] _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" @@ -56,7 +53,7 @@ def get_studio_chat(): { "model": _DUMMY_MODEL, "system": _DUMMY_SYSTEM, - "messages": _DUMMY_MESSAGES, + "messages": [message.to_dict() for message in _DUMMY_MESSAGES], "temperature": 0.7, "maxTokens": 300, "minTokens": 0, From 92c3f5de0b35510a3271a389160ab06cc064c7ea Mon Sep 17 00:00:00 2001 From: etang Date: Tue, 2 Jan 2024 13:16:21 +0200 Subject: [PATCH 12/45] fix: bump version From 1ed334b466acca6a8e03c49a387e5918b7a69d45 Mon Sep 17 00:00:00 2001 From: github-actions Date: Tue, 2 Jan 2024 11:18:20 +0000 Subject: [PATCH 13/45] chore(release): v2.0.0-rc.6 [skip ci] --- CHANGELOG.md | 43 +++++++++++++++++++++++++++++++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 39446da2..c591603a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,51 @@ +## v2.0.0-rc.6 (2024-01-02) + +### Fix + +* fix: bump version ([`92c3f5d`](https://github.com/AI21Labs/ai21-python/commit/92c3f5de0b35510a3271a389160ab06cc064c7ea)) + +### Refactor + +* refactor: Add enums (#30) + +* refactor: answer enum + +* refactor: answer - mode enum + +* refactor: moved imports + +* refactor: Added enums to chat requests/response + +* refactor: Added enums to completion requests/response + +* fix: imports + +* refactor: Added embed types enum + +* refactor: Added correction type enum + +* refactor: Added improvement type enum + +* refactor: Added enums to paraphrase and library answer + +* refactor: Added enums to segmentation + +* refactor: Added enums to summary + +* refactor: Added enums to summary by segment + +* fix: test ([`f84f86a`](https://github.com/AI21Labs/ai21-python/commit/f84f86ab4992701ba8ff22b262317e1b336ea785)) + + ## v2.0.0-rc.5 (2023-12-27) +### Chore + +* chore(release): v2.0.0-rc.5 [skip ci] ([`916c7b4`](https://github.com/AI21Labs/ai21-python/commit/916c7b40395eb2678cf4e66d49206fc94bbe9b73)) + ### Ci * ci: Remove python 3_7 support (#25) diff --git a/ai21/version.py b/ai21/version.py index ba278f7c..e5fcb364 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.5" +VERSION = "2.0.0-rc.6" diff --git a/pyproject.toml b/pyproject.toml index 9fec5885..53b2f7a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.5" +version = "2.0.0-rc.6" description = "" authors = ["AI21 Labs"] readme = "README.md" From 9e8a1f05dc4afaecfc525fa22cb76594d6cf0c8d Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Tue, 2 Jan 2024 16:15:58 +0200 Subject: [PATCH 14/45] fix: Restructure packages (#31) * refactor: moved classes to models package * refactor: moved responses = to models package * refactor: moved resources to common package * refactor: chat message rename * refactor: init * fix: imports * refactor: added more to imports --- ai21/__init__.py | 30 -------- ai21/ai21_http_client.py | 2 +- .../bedrock/resources/bedrock_completion.py | 4 +- .../bedrock}/resources/bedrock_resource.py | 0 .../bases => clients/common}/__init__.py | 0 .../bases => clients/common}/answer_base.py | 5 +- .../bases => clients/common}/chat_base.py | 19 ++--- .../common}/completion_base.py | 4 +- .../common}/custom_model_base.py | 2 +- .../bases => clients/common}/dataset_base.py | 2 +- .../bases => clients/common}/embed_base.py | 9 +-- .../bases => clients/common}/gec_base.py | 2 +- .../common}/improvements_base.py | 4 +- .../common}/paraphrase_base.py | 4 +- .../common}/segmentation_base.py | 4 +- .../common}/summarize_base.py | 4 +- .../common}/summarize_by_segment_base.py | 4 +- .../sagemaker/resources/sagemaker_answer.py | 6 +- .../resources/sagemaker_completion.py | 4 +- .../sagemaker/resources/sagemaker_gec.py | 6 +- .../resources/sagemaker_paraphrase.py | 8 +- .../resources/sagemaker_resource.py | 0 .../resources/sagemaker_summarize.py | 8 +- .../clients/studio/resources/studio_answer.py | 7 +- ai21/clients/studio/resources/studio_chat.py | 9 ++- .../studio/resources/studio_completion.py | 8 +- .../studio/resources/studio_custom_model.py | 6 +- .../studio/resources/studio_dataset.py | 6 +- ai21/clients/studio/resources/studio_embed.py | 6 +- ai21/clients/studio/resources/studio_gec.py | 6 +- .../studio/resources/studio_improvements.py | 8 +- .../studio/resources/studio_library.py | 11 ++- .../studio/resources/studio_paraphrase.py | 8 +- .../studio}/resources/studio_resource.py | 0 .../studio/resources/studio_segmentation.py | 8 +- .../studio/resources/studio_summarize.py | 8 +- .../resources/studio_summarize_by_segment.py | 8 +- ai21/models/__init__.py | 76 +++++++++++++++++++ ai21/{resources => }/models/answer_length.py | 0 ai21/models/chat_message.py | 12 +++ ai21/{resources => }/models/document_type.py | 0 ai21/models/embed_type.py | 6 ++ .../models/improvement_type.py | 0 ai21/{resources => }/models/mode.py | 0 .../models/paraphrase_style_type.py | 0 ai21/{resources => }/models/penalty.py | 0 .../models => models/responses}/__init__.py | 0 .../responses/answer_response.py | 4 +- .../responses/chat_response.py | 2 +- .../responses/completion_response.py | 0 .../responses/custom_model_response.py | 0 .../responses/dataset_response.py | 0 .../responses/embed_response.py | 0 .../responses/file_response.py | 0 .../responses/gec_response.py | 0 .../responses/improvement_response.py | 0 .../responses/library_answer_response.py | 0 .../responses/library_search_response.py | 0 .../responses/paraphrase_response.py | 0 .../responses/segmentation_response.py | 0 .../summarize_by_segment_response.py | 0 .../responses/summarize_response.py | 0 ai21/{resources => }/models/role_type.py | 0 ai21/{resources => }/models/summary_method.py | 0 ai21/resources/__init__.py | 23 ------ ai21/resources/responses/__init__.py | 0 .../clients/studio/resources/conftest.py | 17 +++-- .../studio/resources/test_studio_resources.py | 4 +- 68 files changed, 203 insertions(+), 171 deletions(-) rename ai21/{ => clients/bedrock}/resources/bedrock_resource.py (100%) rename ai21/{resources/bases => clients/common}/__init__.py (100%) rename ai21/{resources/bases => clients/common}/answer_base.py (81%) rename ai21/{resources/bases => clients/common}/chat_base.py (82%) rename ai21/{resources/bases => clients/common}/completion_base.py (94%) rename ai21/{resources/bases => clients/common}/custom_model_base.py (93%) rename ai21/{resources/bases => clients/common}/dataset_base.py (94%) rename ai21/{resources/bases => clients/common}/embed_base.py (76%) rename ai21/{resources/bases => clients/common}/gec_base.py (86%) rename ai21/{resources/bases => clients/common}/improvements_base.py (78%) rename ai21/{resources/bases => clients/common}/paraphrase_base.py (85%) rename ai21/{resources/bases => clients/common}/segmentation_base.py (78%) rename ai21/{resources/bases => clients/common}/summarize_base.py (85%) rename ai21/{resources/bases => clients/common}/summarize_by_segment_base.py (86%) rename ai21/{ => clients/sagemaker}/resources/sagemaker_resource.py (100%) rename ai21/{ => clients/studio}/resources/studio_resource.py (100%) rename ai21/{resources => }/models/answer_length.py (100%) create mode 100644 ai21/models/chat_message.py rename ai21/{resources => }/models/document_type.py (100%) create mode 100644 ai21/models/embed_type.py rename ai21/{resources => }/models/improvement_type.py (100%) rename ai21/{resources => }/models/mode.py (100%) rename ai21/{resources => }/models/paraphrase_style_type.py (100%) rename ai21/{resources => }/models/penalty.py (100%) rename ai21/{resources/models => models/responses}/__init__.py (100%) rename ai21/{resources => models}/responses/answer_response.py (71%) rename ai21/{resources => models}/responses/chat_response.py (89%) rename ai21/{resources => models}/responses/completion_response.py (100%) rename ai21/{resources => models}/responses/custom_model_response.py (100%) rename ai21/{resources => models}/responses/dataset_response.py (100%) rename ai21/{resources => models}/responses/embed_response.py (100%) rename ai21/{resources => models}/responses/file_response.py (100%) rename ai21/{resources => models}/responses/gec_response.py (100%) rename ai21/{resources => models}/responses/improvement_response.py (100%) rename ai21/{resources => models}/responses/library_answer_response.py (100%) rename ai21/{resources => models}/responses/library_search_response.py (100%) rename ai21/{resources => models}/responses/paraphrase_response.py (100%) rename ai21/{resources => models}/responses/segmentation_response.py (100%) rename ai21/{resources => models}/responses/summarize_by_segment_response.py (100%) rename ai21/{resources => models}/responses/summarize_response.py (100%) rename ai21/{resources => }/models/role_type.py (100%) rename ai21/{resources => }/models/summary_method.py (100%) delete mode 100644 ai21/resources/__init__.py delete mode 100644 ai21/resources/responses/__init__.py diff --git a/ai21/__init__.py b/ai21/__init__.py index 6c5fb3e9..bfc95fa8 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -10,21 +10,6 @@ TooManyRequestsError, ) from ai21.logger import setup_logger -from ai21.resources.responses.answer_response import AnswerResponse -from ai21.resources.responses.chat_response import ChatResponse -from ai21.resources.responses.completion_response import CompletionsResponse -from ai21.resources.responses.custom_model_response import CustomBaseModelResponse -from ai21.resources.responses.dataset_response import DatasetResponse -from ai21.resources.responses.embed_response import EmbedResponse -from ai21.resources.responses.file_response import FileResponse -from ai21.resources.responses.gec_response import GECResponse -from ai21.resources.responses.improvement_response import ImprovementsResponse -from ai21.resources.responses.library_answer_response import LibraryAnswerResponse -from ai21.resources.responses.library_search_response import LibrarySearchResponse -from ai21.resources.responses.paraphrase_response import ParaphraseResponse -from ai21.resources.responses.segmentation_response import SegmentationResponse -from ai21.resources.responses.summarize_by_segment_response import SummarizeBySegmentResponse -from ai21.resources.responses.summarize_response import SummarizeResponse from ai21.services.sagemaker import SageMaker from ai21.version import VERSION @@ -75,20 +60,5 @@ def __getattr__(name: str) -> Any: "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID", - "AnswerResponse", - "ChatResponse", - "CompletionsResponse", - "CustomBaseModelResponse", - "DatasetResponse", - "EmbedResponse", - "FileResponse", - "GECResponse", - "ImprovementsResponse", - "LibraryAnswerResponse", - "LibrarySearchResponse", - "ParaphraseResponse", "SageMaker", - "SegmentationResponse", - "SummarizeBySegmentResponse", - "SummarizeResponse", ] diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 68007654..995f1f19 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -25,7 +25,7 @@ def __init__( self._env_config = env_config self._api_key = api_key or self._env_config.api_key - if self._api_key is None: + if not self._api_key: raise MissingApiKeyError() self._api_host = api_host or self._env_config.api_host diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index 150cf381..f9ff4646 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,8 +1,8 @@ from typing import Optional, List from ai21.resources import Penalty -from ai21.resources.bedrock_resource import BedrockResource -from ai21.resources.responses.completion_response import CompletionsResponse +from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource +from ai21.models.responses.completion_response import CompletionsResponse class BedrockCompletion(BedrockResource): diff --git a/ai21/resources/bedrock_resource.py b/ai21/clients/bedrock/resources/bedrock_resource.py similarity index 100% rename from ai21/resources/bedrock_resource.py rename to ai21/clients/bedrock/resources/bedrock_resource.py diff --git a/ai21/resources/bases/__init__.py b/ai21/clients/common/__init__.py similarity index 100% rename from ai21/resources/bases/__init__.py rename to ai21/clients/common/__init__.py diff --git a/ai21/resources/bases/answer_base.py b/ai21/clients/common/answer_base.py similarity index 81% rename from ai21/resources/bases/answer_base.py rename to ai21/clients/common/answer_base.py index 96663084..a1543646 100644 --- a/ai21/resources/bases/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -1,9 +1,8 @@ from abc import ABC from typing import Optional, Any, Dict -from ai21.resources.models.answer_length import AnswerLength -from ai21.resources.models.mode import Mode -from ai21.resources.responses.answer_response import AnswerResponse +from ai21.models import Mode, AnswerLength +from ai21.models.responses.answer_response import AnswerResponse class Answer(ABC): diff --git a/ai21/resources/bases/chat_base.py b/ai21/clients/common/chat_base.py similarity index 82% rename from ai21/resources/bases/chat_base.py rename to ai21/clients/common/chat_base.py index 54854bb7..869ddfaa 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -1,18 +1,9 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import List, Any, Dict, Optional -from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin -from ai21.resources.models.penalty import Penalty -from ai21.resources.models.role_type import RoleType -from ai21.resources.responses.chat_response import ChatResponse - - -@dataclass -class Message(AI21BaseModelMixin): - role: RoleType - text: str - name: Optional[str] +from ai21.models.chat_message import ChatMessage +from ai21.models.penalty import Penalty +from ai21.models.responses.chat_response import ChatResponse class Chat(ABC): @@ -22,7 +13,7 @@ class Chat(ABC): def create( self, model: str, - messages: List[Message], + messages: List[ChatMessage], system: str, *, num_results: Optional[int] = 1, @@ -45,7 +36,7 @@ def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: def _create_body( self, model: str, - messages: List[Message], + messages: List[ChatMessage], system: str, num_results: Optional[int] = 1, temperature: Optional[float] = 0.7, diff --git a/ai21/resources/bases/completion_base.py b/ai21/clients/common/completion_base.py similarity index 94% rename from ai21/resources/bases/completion_base.py rename to ai21/clients/common/completion_base.py index df9f2352..a2bc8c3d 100644 --- a/ai21/resources/bases/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Optional, List, Dict, Any -from ai21.resources import Penalty -from ai21.resources.responses.completion_response import CompletionsResponse +from ai21.models import Penalty +from ai21.models.responses.completion_response import CompletionsResponse class Completion(ABC): diff --git a/ai21/resources/bases/custom_model_base.py b/ai21/clients/common/custom_model_base.py similarity index 93% rename from ai21/resources/bases/custom_model_base.py rename to ai21/clients/common/custom_model_base.py index 6e360a17..7d3b55ae 100644 --- a/ai21/resources/bases/custom_model_base.py +++ b/ai21/clients/common/custom_model_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, List, Any, Dict -from ai21.resources.responses.custom_model_response import CustomBaseModelResponse +from ai21.models.responses.custom_model_response import CustomBaseModelResponse class CustomModel(ABC): diff --git a/ai21/resources/bases/dataset_base.py b/ai21/clients/common/dataset_base.py similarity index 94% rename from ai21/resources/bases/dataset_base.py rename to ai21/clients/common/dataset_base.py index 2be49fc7..9fa57f85 100644 --- a/ai21/resources/bases/dataset_base.py +++ b/ai21/clients/common/dataset_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.resources.responses.dataset_response import DatasetResponse +from ai21.models.responses.dataset_response import DatasetResponse class Dataset(ABC): diff --git a/ai21/resources/bases/embed_base.py b/ai21/clients/common/embed_base.py similarity index 76% rename from ai21/resources/bases/embed_base.py rename to ai21/clients/common/embed_base.py index ee07187b..baadd4ec 100644 --- a/ai21/resources/bases/embed_base.py +++ b/ai21/clients/common/embed_base.py @@ -1,13 +1,8 @@ from abc import ABC, abstractmethod -from enum import Enum from typing import List, Any, Dict, Optional -from ai21.resources.responses.embed_response import EmbedResponse - - -class EmbedType(str, Enum): - QUERY = "query" - SEGMENT = "segment" +from ai21.models.embed_type import EmbedType +from ai21.models.responses.embed_response import EmbedResponse class Embed(ABC): diff --git a/ai21/resources/bases/gec_base.py b/ai21/clients/common/gec_base.py similarity index 86% rename from ai21/resources/bases/gec_base.py rename to ai21/clients/common/gec_base.py index 0d623dfd..8de743e2 100644 --- a/ai21/resources/bases/gec_base.py +++ b/ai21/clients/common/gec_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, Any -from ai21.resources.responses.gec_response import GECResponse +from ai21.models.responses.gec_response import GECResponse class GEC(ABC): diff --git a/ai21/resources/bases/improvements_base.py b/ai21/clients/common/improvements_base.py similarity index 78% rename from ai21/resources/bases/improvements_base.py rename to ai21/clients/common/improvements_base.py index 69ac49e7..df912e1d 100644 --- a/ai21/resources/bases/improvements_base.py +++ b/ai21/clients/common/improvements_base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -from ai21.resources.models.improvement_type import ImprovementType -from ai21.resources.responses.improvement_response import ImprovementsResponse +from ai21.models import ImprovementType +from ai21.models.responses.improvement_response import ImprovementsResponse class Improvements(ABC): diff --git a/ai21/resources/bases/paraphrase_base.py b/ai21/clients/common/paraphrase_base.py similarity index 85% rename from ai21/resources/bases/paraphrase_base.py rename to ai21/clients/common/paraphrase_base.py index aa2e2e7d..917cdd75 100644 --- a/ai21/resources/bases/paraphrase_base.py +++ b/ai21/clients/common/paraphrase_base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType -from ai21.resources.responses.paraphrase_response import ParaphraseResponse +from ai21.models import ParaphraseStyleType +from ai21.models.responses.paraphrase_response import ParaphraseResponse class Paraphrase(ABC): diff --git a/ai21/resources/bases/segmentation_base.py b/ai21/clients/common/segmentation_base.py similarity index 78% rename from ai21/resources/bases/segmentation_base.py rename to ai21/clients/common/segmentation_base.py index f20cecb3..c4f658a9 100644 --- a/ai21/resources/bases/segmentation_base.py +++ b/ai21/clients/common/segmentation_base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict -from ai21.resources.models.document_type import DocumentType -from ai21.resources.responses.segmentation_response import SegmentationResponse +from ai21.models.document_type import DocumentType +from ai21.models.responses.segmentation_response import SegmentationResponse class Segmentation(ABC): diff --git a/ai21/resources/bases/summarize_base.py b/ai21/clients/common/summarize_base.py similarity index 85% rename from ai21/resources/bases/summarize_base.py rename to ai21/clients/common/summarize_base.py index 435549a5..85cec2f5 100644 --- a/ai21/resources/bases/summarize_base.py +++ b/ai21/clients/common/summarize_base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.resources.models.summary_method import SummaryMethod -from ai21.resources.responses.summarize_response import SummarizeResponse +from ai21.models.responses.summarize_response import SummarizeResponse +from ai21.models.summary_method import SummaryMethod class Summarize(ABC): diff --git a/ai21/resources/bases/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py similarity index 86% rename from ai21/resources/bases/summarize_by_segment_base.py rename to ai21/clients/common/summarize_by_segment_base.py index ac30d1ae..40a4abfa 100644 --- a/ai21/resources/bases/summarize_by_segment_base.py +++ b/ai21/clients/common/summarize_by_segment_base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.resources.models.document_type import DocumentType -from ai21.resources.responses.summarize_by_segment_response import ( +from ai21.models.document_type import DocumentType +from ai21.models.responses.summarize_by_segment_response import ( SummarizeBySegmentResponse, ) diff --git a/ai21/clients/sagemaker/resources/sagemaker_answer.py b/ai21/clients/sagemaker/resources/sagemaker_answer.py index 1584abc9..f344d6b0 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_answer.py +++ b/ai21/clients/sagemaker/resources/sagemaker_answer.py @@ -1,8 +1,8 @@ from typing import Optional -from ai21.resources.bases.answer_base import Answer, AnswerLength, Mode -from ai21.resources.responses.answer_response import AnswerResponse -from ai21.resources.sagemaker_resource import SageMakerResource +from ai21.clients.common.answer_base import Answer, AnswerLength, Mode +from ai21.models.responses.answer_response import AnswerResponse +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource class SageMakerAnswer(SageMakerResource, Answer): diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 687f9b94..373fdcb8 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,8 +1,8 @@ from typing import Optional, List from ai21.resources import Penalty -from ai21.resources.responses.completion_response import CompletionsResponse -from ai21.resources.sagemaker_resource import SageMakerResource +from ai21.models.responses.completion_response import CompletionsResponse +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource class SageMakerCompletion(SageMakerResource): diff --git a/ai21/clients/sagemaker/resources/sagemaker_gec.py b/ai21/clients/sagemaker/resources/sagemaker_gec.py index db8717e8..0750a7ea 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_gec.py +++ b/ai21/clients/sagemaker/resources/sagemaker_gec.py @@ -1,6 +1,6 @@ -from ai21.resources.bases.gec_base import GEC -from ai21.resources.responses.gec_response import GECResponse -from ai21.resources.sagemaker_resource import SageMakerResource +from ai21.clients.common.gec_base import GEC +from ai21.models.responses.gec_response import GECResponse +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource class SageMakerGEC(SageMakerResource, GEC): diff --git a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py index 49d9ce5d..b2588019 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py +++ b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py @@ -1,9 +1,9 @@ from typing import Optional -from ai21.resources.bases.paraphrase_base import Paraphrase -from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType -from ai21.resources.responses.paraphrase_response import ParaphraseResponse -from ai21.resources.sagemaker_resource import SageMakerResource +from ai21.clients.common.paraphrase_base import Paraphrase +from ai21.models.paraphrase_style_type import ParaphraseStyleType +from ai21.models.responses.paraphrase_response import ParaphraseResponse +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource class SageMakerParaphrase(SageMakerResource, Paraphrase): diff --git a/ai21/resources/sagemaker_resource.py b/ai21/clients/sagemaker/resources/sagemaker_resource.py similarity index 100% rename from ai21/resources/sagemaker_resource.py rename to ai21/clients/sagemaker/resources/sagemaker_resource.py diff --git a/ai21/clients/sagemaker/resources/sagemaker_summarize.py b/ai21/clients/sagemaker/resources/sagemaker_summarize.py index c7db557b..1d5a7bc2 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_summarize.py +++ b/ai21/clients/sagemaker/resources/sagemaker_summarize.py @@ -2,10 +2,10 @@ from typing import Optional -from ai21.resources.bases.summarize_base import Summarize -from ai21.resources.models.summary_method import SummaryMethod -from ai21.resources.responses.summarize_response import SummarizeResponse -from ai21.resources.sagemaker_resource import SageMakerResource +from ai21.clients.common.summarize_base import Summarize +from ai21.models.summary_method import SummaryMethod +from ai21.models.responses import SummarizeResponse +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource class SageMakerSummarize(SageMakerResource, Summarize): diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index 3962353c..a8cfc13e 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -1,8 +1,9 @@ from typing import Optional -from ai21.resources.bases.answer_base import Answer, AnswerLength, Mode -from ai21.resources.responses.answer_response import AnswerResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.answer_base import Answer +from ai21.models import AnswerLength, Mode +from ai21.models.responses.answer_response import AnswerResponse +from ai21.clients.studio.resources.studio_resource import StudioResource class StudioAnswer(StudioResource, Answer): diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index f1dab12b..710ed308 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,15 +1,16 @@ from typing import List, Any, Optional, Dict -from ai21.resources.bases.chat_base import Chat, Message -from ai21.resources.responses.chat_response import ChatResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.chat_base import Chat +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.chat_message import ChatMessage +from ai21.models.responses.chat_response import ChatResponse class StudioChat(StudioResource, Chat): def create( self, model: str, - messages: List[Message], + messages: List[ChatMessage], system: str, *, num_results: Optional[int] = 1, diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index bcac84f8..75364113 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -1,9 +1,9 @@ from typing import Optional, List -from ai21.resources import Penalty -from ai21.resources.bases.completion_base import Completion -from ai21.resources.responses.completion_response import CompletionsResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.completion_base import Completion +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import Penalty +from ai21.models.responses.completion_response import CompletionsResponse class StudioCompletion(StudioResource, Completion): diff --git a/ai21/clients/studio/resources/studio_custom_model.py b/ai21/clients/studio/resources/studio_custom_model.py index 32166e53..831b4575 100644 --- a/ai21/clients/studio/resources/studio_custom_model.py +++ b/ai21/clients/studio/resources/studio_custom_model.py @@ -1,8 +1,8 @@ from typing import List, Optional -from ai21.resources.bases.custom_model_base import CustomModel -from ai21.resources.responses.custom_model_response import CustomBaseModelResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.custom_model_base import CustomModel +from ai21.models.responses.custom_model_response import CustomBaseModelResponse +from ai21.clients.studio.resources.studio_resource import StudioResource class StudioCustomModel(StudioResource, CustomModel): diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 8626d71b..ccfb4bac 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -1,8 +1,8 @@ from typing import Optional, List -from ai21.resources.bases.dataset_base import Dataset -from ai21.resources.responses.dataset_response import DatasetResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.dataset_base import Dataset +from ai21.models.responses.dataset_response import DatasetResponse +from ai21.clients.studio.resources.studio_resource import StudioResource class StudioDataset(StudioResource, Dataset): diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index 80cad1c7..7495ea67 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -1,8 +1,8 @@ from typing import List, Optional -from ai21.resources.bases.embed_base import Embed -from ai21.resources.responses.embed_response import EmbedResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.embed_base import Embed +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.responses.embed_response import EmbedResponse class StudioEmbed(StudioResource, Embed): diff --git a/ai21/clients/studio/resources/studio_gec.py b/ai21/clients/studio/resources/studio_gec.py index 3ce45a6c..3e716e7a 100644 --- a/ai21/clients/studio/resources/studio_gec.py +++ b/ai21/clients/studio/resources/studio_gec.py @@ -1,6 +1,6 @@ -from ai21.resources.bases.gec_base import GEC -from ai21.resources.responses.gec_response import GECResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.gec_base import GEC +from ai21.models.responses.gec_response import GECResponse +from ai21.clients.studio.resources.studio_resource import StudioResource class StudioGEC(StudioResource, GEC): diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 2e17cfd9..513a413c 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -1,10 +1,10 @@ from typing import List +from ai21.clients.common.improvements_base import Improvements +from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.errors import EmptyMandatoryListError -from ai21.resources.bases.improvements_base import Improvements -from ai21.resources.models.improvement_type import ImprovementType -from ai21.resources.responses.improvement_response import ImprovementsResponse -from ai21.resources.studio_resource import StudioResource +from ai21.models import ImprovementType +from ai21.models.responses.improvement_response import ImprovementsResponse class StudioImprovements(StudioResource, Improvements): diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index 1962cd58..fa73b123 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,12 +1,11 @@ from typing import Optional, List from ai21.ai21_http_client import AI21HTTPClient -from ai21.resources.models.answer_length import AnswerLength -from ai21.resources.models.mode import Mode -from ai21.resources.responses.file_response import FileResponse -from ai21.resources.responses.library_answer_response import LibraryAnswerResponse -from ai21.resources.responses.library_search_response import LibrarySearchResponse -from ai21.resources.studio_resource import StudioResource +from ai21.models import Mode, AnswerLength +from ai21.models.responses.file_response import FileResponse +from ai21.models.responses.library_answer_response import LibraryAnswerResponse +from ai21.models.responses.library_search_response import LibrarySearchResponse +from ai21.clients.studio.resources.studio_resource import StudioResource class StudioLibrary(StudioResource): diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py index 1dbf06fb..686841db 100644 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ b/ai21/clients/studio/resources/studio_paraphrase.py @@ -1,9 +1,9 @@ from typing import Optional -from ai21.resources.bases.paraphrase_base import Paraphrase -from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType -from ai21.resources.responses.paraphrase_response import ParaphraseResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.paraphrase_base import Paraphrase +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import ParaphraseStyleType +from ai21.models.responses.paraphrase_response import ParaphraseResponse class StudioParaphrase(StudioResource, Paraphrase): diff --git a/ai21/resources/studio_resource.py b/ai21/clients/studio/resources/studio_resource.py similarity index 100% rename from ai21/resources/studio_resource.py rename to ai21/clients/studio/resources/studio_resource.py diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index 0586561f..dbda4225 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -1,7 +1,7 @@ -from ai21.resources.bases.segmentation_base import Segmentation -from ai21.resources.models.document_type import DocumentType -from ai21.resources.responses.segmentation_response import SegmentationResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.segmentation_base import Segmentation +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.document_type import DocumentType +from ai21.models.responses.segmentation_response import SegmentationResponse class StudioSegmentation(StudioResource, Segmentation): diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index 1b9c44d3..b2b5f860 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -1,9 +1,9 @@ from typing import Optional -from ai21.resources.bases.summarize_base import Summarize -from ai21.resources.models.summary_method import SummaryMethod -from ai21.resources.responses.summarize_response import SummarizeResponse -from ai21.resources.studio_resource import StudioResource +from ai21.clients.common.summarize_base import Summarize +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.responses.summarize_response import SummarizeResponse +from ai21.models.summary_method import SummaryMethod class StudioSummarize(StudioResource, Summarize): diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py index 7644a5f4..8292a1f9 100644 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ b/ai21/clients/studio/resources/studio_summarize_by_segment.py @@ -1,11 +1,11 @@ from typing import Optional -from ai21.resources.bases.summarize_by_segment_base import SummarizeBySegment -from ai21.resources.models.document_type import DocumentType -from ai21.resources.responses.summarize_by_segment_response import ( +from ai21.clients.common.summarize_by_segment_base import SummarizeBySegment +from ai21.models.document_type import DocumentType +from ai21.models.responses.summarize_by_segment_response import ( SummarizeBySegmentResponse, ) -from ai21.resources.studio_resource import StudioResource +from ai21.clients.studio.resources.studio_resource import StudioResource class StudioSummarizeBySegment(StudioResource, SummarizeBySegment): diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py index e69de29b..5e5877d8 100644 --- a/ai21/models/__init__.py +++ b/ai21/models/__init__.py @@ -0,0 +1,76 @@ +from ai21.models.answer_length import AnswerLength +from ai21.models.chat_message import ChatMessage +from ai21.models.document_type import DocumentType +from ai21.models.embed_type import EmbedType +from ai21.models.improvement_type import ImprovementType +from ai21.models.mode import Mode +from ai21.models.paraphrase_style_type import ParaphraseStyleType +from ai21.models.penalty import Penalty +from ai21.models.responses.answer_response import AnswerResponse +from ai21.models.responses.chat_response import ChatResponse, ChatOutput, FinishReason +from ai21.models.responses.completion_response import ( + CompletionsResponse, + Completion, + CompletionFinishReason, + CompletionData, + Prompt, +) +from ai21.models.responses.custom_model_response import CustomBaseModelResponse, BaseModelMetadata +from ai21.models.responses.dataset_response import DatasetResponse +from ai21.models.responses.embed_response import EmbedResponse, EmbedResult +from ai21.models.responses.file_response import FileResponse +from ai21.models.responses.gec_response import GECResponse, Correction, CorrectionType +from ai21.models.responses.improvement_response import ImprovementsResponse, Improvement +from ai21.models.responses.library_answer_response import LibraryAnswerResponse, SourceDocument +from ai21.models.responses.library_search_response import LibrarySearchResponse, LibrarySearchResult +from ai21.models.responses.paraphrase_response import ParaphraseResponse, Suggestion +from ai21.models.responses.segmentation_response import SegmentationResponse +from ai21.models.responses.summarize_by_segment_response import SummarizeBySegmentResponse, SegmentSummary, Highlight +from ai21.models.responses.summarize_response import SummarizeResponse +from ai21.models.role_type import RoleType +from ai21.models.summary_method import SummaryMethod + + +__all__ = [ + "AnswerLength", + "Mode", + "ChatMessage", + "RoleType", + "Penalty", + "EmbedType", + "ImprovementType", + "ParaphraseStyleType", + "DocumentType", + "SummaryMethod", + "AnswerResponse", + "ChatResponse", + "ChatOutput", + "FinishReason", + "CompletionsResponse", + "Completion", + "CompletionFinishReason", + "CompletionData", + "Prompt", + "CustomBaseModelResponse", + "BaseModelMetadata", + "DatasetResponse", + "EmbedResponse", + "EmbedResult", + "FileResponse", + "GECResponse", + "Correction", + "CorrectionType", + "ImprovementsResponse", + "Improvement", + "LibraryAnswerResponse", + "SourceDocument", + "LibrarySearchResponse", + "LibrarySearchResult", + "ParaphraseResponse", + "Suggestion", + "SegmentationResponse", + "SegmentSummary", + "Highlight", + "SummarizeBySegmentResponse", + "SummarizeResponse", +] diff --git a/ai21/resources/models/answer_length.py b/ai21/models/answer_length.py similarity index 100% rename from ai21/resources/models/answer_length.py rename to ai21/models/answer_length.py diff --git a/ai21/models/chat_message.py b/ai21/models/chat_message.py new file mode 100644 index 00000000..ebf719e3 --- /dev/null +++ b/ai21/models/chat_message.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Optional + +from ai21.models.role_type import RoleType +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class ChatMessage(AI21BaseModelMixin): + role: RoleType + text: str + name: Optional[str] = None diff --git a/ai21/resources/models/document_type.py b/ai21/models/document_type.py similarity index 100% rename from ai21/resources/models/document_type.py rename to ai21/models/document_type.py diff --git a/ai21/models/embed_type.py b/ai21/models/embed_type.py new file mode 100644 index 00000000..d1268a86 --- /dev/null +++ b/ai21/models/embed_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class EmbedType(str, Enum): + QUERY = "query" + SEGMENT = "segment" diff --git a/ai21/resources/models/improvement_type.py b/ai21/models/improvement_type.py similarity index 100% rename from ai21/resources/models/improvement_type.py rename to ai21/models/improvement_type.py diff --git a/ai21/resources/models/mode.py b/ai21/models/mode.py similarity index 100% rename from ai21/resources/models/mode.py rename to ai21/models/mode.py diff --git a/ai21/resources/models/paraphrase_style_type.py b/ai21/models/paraphrase_style_type.py similarity index 100% rename from ai21/resources/models/paraphrase_style_type.py rename to ai21/models/paraphrase_style_type.py diff --git a/ai21/resources/models/penalty.py b/ai21/models/penalty.py similarity index 100% rename from ai21/resources/models/penalty.py rename to ai21/models/penalty.py diff --git a/ai21/resources/models/__init__.py b/ai21/models/responses/__init__.py similarity index 100% rename from ai21/resources/models/__init__.py rename to ai21/models/responses/__init__.py diff --git a/ai21/resources/responses/answer_response.py b/ai21/models/responses/answer_response.py similarity index 71% rename from ai21/resources/responses/answer_response.py rename to ai21/models/responses/answer_response.py index 3056b6f2..437881b4 100644 --- a/ai21/resources/responses/answer_response.py +++ b/ai21/models/responses/answer_response.py @@ -7,5 +7,5 @@ @dataclass class AnswerResponse(AI21BaseModelMixin): id: str - answer_in_context: Optional[bool] - answer: Optional[str] + answer_in_context: Optional[bool] = None + answer: Optional[str] = None diff --git a/ai21/resources/responses/chat_response.py b/ai21/models/responses/chat_response.py similarity index 89% rename from ai21/resources/responses/chat_response.py rename to ai21/models/responses/chat_response.py index 81ce3062..e1a15a9f 100644 --- a/ai21/resources/responses/chat_response.py +++ b/ai21/models/responses/chat_response.py @@ -2,7 +2,7 @@ from typing import Optional, List from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin -from ai21.resources.models.role_type import RoleType +from ai21.models.role_type import RoleType @dataclass diff --git a/ai21/resources/responses/completion_response.py b/ai21/models/responses/completion_response.py similarity index 100% rename from ai21/resources/responses/completion_response.py rename to ai21/models/responses/completion_response.py diff --git a/ai21/resources/responses/custom_model_response.py b/ai21/models/responses/custom_model_response.py similarity index 100% rename from ai21/resources/responses/custom_model_response.py rename to ai21/models/responses/custom_model_response.py diff --git a/ai21/resources/responses/dataset_response.py b/ai21/models/responses/dataset_response.py similarity index 100% rename from ai21/resources/responses/dataset_response.py rename to ai21/models/responses/dataset_response.py diff --git a/ai21/resources/responses/embed_response.py b/ai21/models/responses/embed_response.py similarity index 100% rename from ai21/resources/responses/embed_response.py rename to ai21/models/responses/embed_response.py diff --git a/ai21/resources/responses/file_response.py b/ai21/models/responses/file_response.py similarity index 100% rename from ai21/resources/responses/file_response.py rename to ai21/models/responses/file_response.py diff --git a/ai21/resources/responses/gec_response.py b/ai21/models/responses/gec_response.py similarity index 100% rename from ai21/resources/responses/gec_response.py rename to ai21/models/responses/gec_response.py diff --git a/ai21/resources/responses/improvement_response.py b/ai21/models/responses/improvement_response.py similarity index 100% rename from ai21/resources/responses/improvement_response.py rename to ai21/models/responses/improvement_response.py diff --git a/ai21/resources/responses/library_answer_response.py b/ai21/models/responses/library_answer_response.py similarity index 100% rename from ai21/resources/responses/library_answer_response.py rename to ai21/models/responses/library_answer_response.py diff --git a/ai21/resources/responses/library_search_response.py b/ai21/models/responses/library_search_response.py similarity index 100% rename from ai21/resources/responses/library_search_response.py rename to ai21/models/responses/library_search_response.py diff --git a/ai21/resources/responses/paraphrase_response.py b/ai21/models/responses/paraphrase_response.py similarity index 100% rename from ai21/resources/responses/paraphrase_response.py rename to ai21/models/responses/paraphrase_response.py diff --git a/ai21/resources/responses/segmentation_response.py b/ai21/models/responses/segmentation_response.py similarity index 100% rename from ai21/resources/responses/segmentation_response.py rename to ai21/models/responses/segmentation_response.py diff --git a/ai21/resources/responses/summarize_by_segment_response.py b/ai21/models/responses/summarize_by_segment_response.py similarity index 100% rename from ai21/resources/responses/summarize_by_segment_response.py rename to ai21/models/responses/summarize_by_segment_response.py diff --git a/ai21/resources/responses/summarize_response.py b/ai21/models/responses/summarize_response.py similarity index 100% rename from ai21/resources/responses/summarize_response.py rename to ai21/models/responses/summarize_response.py diff --git a/ai21/resources/models/role_type.py b/ai21/models/role_type.py similarity index 100% rename from ai21/resources/models/role_type.py rename to ai21/models/role_type.py diff --git a/ai21/resources/models/summary_method.py b/ai21/models/summary_method.py similarity index 100% rename from ai21/resources/models/summary_method.py rename to ai21/models/summary_method.py diff --git a/ai21/resources/__init__.py b/ai21/resources/__init__.py deleted file mode 100644 index 626c7221..00000000 --- a/ai21/resources/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from ai21.resources.bases.chat_base import Message -from ai21.resources.bases.embed_base import EmbedType -from ai21.resources.models.answer_length import AnswerLength -from ai21.resources.models.document_type import DocumentType -from ai21.resources.models.improvement_type import ImprovementType -from ai21.resources.models.mode import Mode -from ai21.resources.models.paraphrase_style_type import ParaphraseStyleType -from ai21.resources.models.penalty import Penalty -from ai21.resources.models.role_type import RoleType -from ai21.resources.models.summary_method import SummaryMethod - -__all__ = [ - "AnswerLength", - "Mode", - "Message", - "RoleType", - "Penalty", - "EmbedType", - "ImprovementType", - "ParaphraseStyleType", - "DocumentType", - "SummaryMethod", -] diff --git a/ai21/resources/responses/__init__.py b/ai21/resources/responses/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 8c50a082..a28e559d 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -1,14 +1,19 @@ import pytest from pytest_mock import MockerFixture -from ai21 import AnswerResponse, ChatResponse, CompletionsResponse from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion -from ai21.resources import Message, RoleType -from ai21.resources.responses.chat_response import ChatOutput, FinishReason -from ai21.resources.responses.completion_response import Prompt, Completion, CompletionData, CompletionFinishReason +from ai21.models import AnswerResponse, ChatMessage, RoleType, ChatResponse +from ai21.models.responses.chat_response import ChatOutput, FinishReason +from ai21.models.responses.completion_response import ( + Prompt, + Completion, + CompletionData, + CompletionFinishReason, + CompletionsResponse, +) @pytest.fixture @@ -37,8 +42,8 @@ def get_studio_answer(): def get_studio_chat(): _DUMMY_MODEL = "dummy-chat-model" _DUMMY_MESSAGES = [ - Message(text="Hello, I need help with a signup process.", role=RoleType.USER, name="Alice"), - Message( + ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER, name="Alice"), + ChatMessage( text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT, name="Bob", diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index 0e4de3af..cc534fe2 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -2,10 +2,10 @@ import pytest -from ai21 import AnswerResponse from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer -from ai21.resources.studio_resource import StudioResource +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.responses.answer_response import AnswerResponse from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion _BASE_URL = "https://test.api.ai21.com/studio/v1" From 49a6ee1bd528b529315d101be0c0cd812839df70 Mon Sep 17 00:00:00 2001 From: github-actions Date: Tue, 2 Jan 2024 14:16:45 +0000 Subject: [PATCH 15/45] chore(release): v2.0.0-rc.7 [skip ci] --- CHANGELOG.md | 25 +++++++++++++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c591603a..583486da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,33 @@ +## v2.0.0-rc.7 (2024-01-02) + +### Fix + +* fix: Restructure packages (#31) + +* refactor: moved classes to models package + +* refactor: moved responses = to models package + +* refactor: moved resources to common package + +* refactor: chat message rename + +* refactor: init + +* fix: imports + +* refactor: added more to imports ([`9e8a1f0`](https://github.com/AI21Labs/ai21-python/commit/9e8a1f05dc4afaecfc525fa22cb76594d6cf0c8d)) + + ## v2.0.0-rc.6 (2024-01-02) +### Chore + +* chore(release): v2.0.0-rc.6 [skip ci] ([`1ed334b`](https://github.com/AI21Labs/ai21-python/commit/1ed334b466acca6a8e03c49a387e5918b7a69d45)) + ### Fix * fix: bump version ([`92c3f5d`](https://github.com/AI21Labs/ai21-python/commit/92c3f5de0b35510a3271a389160ab06cc064c7ea)) diff --git a/ai21/version.py b/ai21/version.py index e5fcb364..8c09251f 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.6" +VERSION = "2.0.0-rc.7" diff --git a/pyproject.toml b/pyproject.toml index 53b2f7a1..23ee2e46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.6" +version = "2.0.0-rc.7" description = "" authors = ["AI21 Labs"] readme = "README.md" From fa199c4cfb00a2e28a054789801146a09d723fd0 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Tue, 2 Jan 2024 16:37:07 +0200 Subject: [PATCH 16/45] fix: Added env config class to init (#32) --- ai21/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ai21/__init__.py b/ai21/__init__.py index bfc95fa8..afe7b643 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,5 +1,6 @@ from typing import Any +from ai21.ai21_env_config import AI21EnvConfig from ai21.clients.studio.ai21_client import AI21Client from ai21.errors import ( AI21APIError, @@ -50,6 +51,7 @@ def __getattr__(name: str) -> Any: __all__ = [ + "AI21EnvConfig", "AI21Client", "AI21APIError", "APITimeoutError", From 6c9c0d02d4df339efac1439a0ef5a0e4e2982587 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Tue, 2 Jan 2024 16:48:18 +0200 Subject: [PATCH 17/45] fix: Added py.typed (#33) --- ai21/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 ai21/py.typed diff --git a/ai21/py.typed b/ai21/py.typed new file mode 100644 index 00000000..e69de29b From 4d4ef7161b156cfed21b010e57140af2c15dc1a4 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:36:21 +0200 Subject: [PATCH 18/45] fix: Pass env config to client ctor (#34) * fix: removed application and organization * fix: tests --- ai21/ai21_http_client.py | 23 +++++------------------ ai21/clients/studio/ai21_client.py | 11 +++++++---- tests/unittests/test_ai21_http_client.py | 23 ++++------------------- 3 files changed, 16 insertions(+), 41 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 995f1f19..d5324f94 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,6 +1,5 @@ from typing import Optional, Dict, Any, BinaryIO -from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.errors import MissingApiKeyError from ai21.http_client import HttpClient from ai21.version import VERSION @@ -16,25 +15,19 @@ def __init__( headers: Optional[Dict[str, Any]] = None, timeout_sec: Optional[int] = None, num_retries: Optional[int] = None, - organization: Optional[str] = None, - application: Optional[str] = None, via: Optional[str] = None, http_client: Optional[HttpClient] = None, - env_config: _AI21EnvConfig = AI21EnvConfig, ): - self._env_config = env_config - self._api_key = api_key or self._env_config.api_key + self._api_key = api_key if not self._api_key: raise MissingApiKeyError() - self._api_host = api_host or self._env_config.api_host - self._api_version = api_version or self._env_config.api_version + self._api_host = api_host + self._api_version = api_version self._headers = headers - self._timeout_sec = timeout_sec or self._env_config.timeout_sec - self._num_retries = num_retries or self._env_config.num_retries - self._organization = organization - self._application = application + self._timeout_sec = timeout_sec + self._num_retries = num_retries self._via = via headers = self._build_headers(passed_headers=headers) @@ -69,12 +62,6 @@ def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str def _build_user_agent(self) -> str: user_agent = f"ai21 studio SDK {VERSION}" - if self._organization is not None: - user_agent = f"{user_agent} organization: {self._organization}" - - if self._application is not None: - user_agent = f"{user_agent} application: {self._application}" - if self._via is not None: user_agent = f"{user_agent} via: {self._via}" diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 5f6cee06..dfb781b5 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -1,5 +1,6 @@ from typing import Optional, Any, Dict +from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat @@ -35,14 +36,16 @@ def __init__( num_retries: Optional[int] = None, via: Optional[str] = None, http_client: Optional[HttpClient] = None, + env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): self._http_client = AI21HTTPClient( - api_key=api_key, - api_host=api_host, + api_key=api_key or env_config.api_key, + api_host=api_host or env_config.api_host, + api_version=env_config.api_version, headers=headers, - timeout_sec=timeout_sec, - num_retries=num_retries, + timeout_sec=timeout_sec or env_config.timeout_sec, + num_retries=num_retries or env_config.num_retries, via=via, http_client=http_client, ) diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py index 67c4f197..1710cc19 100644 --- a/tests/unittests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -33,27 +33,14 @@ class TestAI21StudioClient: @pytest.mark.parametrize( ids=[ "when_pass_only_via__should_include_via_in_user_agent", - "when_pass_only_application__should_include_application_in_user_agent", - "when_pass_organization__should_include_organization_in_user_agent", - "when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent", ], - argnames=["via", "application", "organization", "expected_user_agent"], + argnames=["via", "expected_user_agent"], argvalues=[ - ("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"), - (None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"), - (None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"), - ( - "langchain", - "studio", - "ai21", - f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain", - ), + ("langchain", f"ai21 studio SDK {VERSION} via: langchain"), ], ) - def test__build_headers__user_agent( - self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str - ): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization) + def test__build_headers__user_agent(self, via: Optional[str], expected_user_agent: str): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via) assert client._http_client._headers["User-Agent"] == expected_user_agent def test__build_headers__authorization(self): @@ -67,12 +54,10 @@ def test__build_headers__when_pass_headers__should_append(self): @pytest.mark.parametrize( ids=[ - "when_api_host_is_not_set__should_return_default", "when_api_host_is_set__should_return_set_value", ], argnames=["api_host", "expected_api_host"], argvalues=[ - (None, "https://api.ai21.com/studio/v1"), ("http://test_host", "http://test_host/studio/v1"), ], ) From 4336c46352f61cae80f44e413059250a5fd9c409 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:39:02 +0200 Subject: [PATCH 19/45] fix: env vars to http client in sagemaker (#35) * fix: env vars to http * fix: env vars to http --- ai21/services/sagemaker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index f51e1ae2..468f775a 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -1,5 +1,6 @@ from typing import List +from ai21 import AI21EnvConfig from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.sagemaker.constants import ( SAGEMAKER_MODEL_PACKAGE_NAMES, @@ -56,7 +57,13 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: @classmethod def _create_ai21_http_client(cls) -> AI21HTTPClient: - return AI21HTTPClient() + return AI21HTTPClient( + api_key=AI21EnvConfig.api_key, + api_host=AI21EnvConfig.api_host, + api_version=AI21EnvConfig.api_version, + timeout_sec=AI21EnvConfig.timeout_sec, + num_retries=AI21EnvConfig.num_retries, + ) def _assert_model_package_exists(model_name, region): From b766c770c0540cd70657302cea7cc95b869992d6 Mon Sep 17 00:00:00 2001 From: github-actions Date: Wed, 3 Jan 2024 08:39:54 +0000 Subject: [PATCH 20/45] chore(release): v2.0.0-rc.8 [skip ci] --- CHANGELOG.md | 25 +++++++++++++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 583486da..54d920cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,33 @@ +## v2.0.0-rc.8 (2024-01-03) + +### Fix + +* fix: env vars to http client in sagemaker (#35) + +* fix: env vars to http + +* fix: env vars to http ([`4336c46`](https://github.com/AI21Labs/ai21-python/commit/4336c46352f61cae80f44e413059250a5fd9c409)) + +* fix: Pass env config to client ctor (#34) + +* fix: removed application and organization + +* fix: tests ([`4d4ef71`](https://github.com/AI21Labs/ai21-python/commit/4d4ef7161b156cfed21b010e57140af2c15dc1a4)) + +* fix: Added py.typed (#33) ([`6c9c0d0`](https://github.com/AI21Labs/ai21-python/commit/6c9c0d02d4df339efac1439a0ef5a0e4e2982587)) + +* fix: Added env config class to init (#32) ([`fa199c4`](https://github.com/AI21Labs/ai21-python/commit/fa199c4cfb00a2e28a054789801146a09d723fd0)) + + ## v2.0.0-rc.7 (2024-01-02) +### Chore + +* chore(release): v2.0.0-rc.7 [skip ci] ([`49a6ee1`](https://github.com/AI21Labs/ai21-python/commit/49a6ee1bd528b529315d101be0c0cd812839df70)) + ### Fix * fix: Restructure packages (#31) diff --git a/ai21/version.py b/ai21/version.py index 8c09251f..8eed6ff1 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.7" +VERSION = "2.0.0-rc.8" diff --git a/pyproject.toml b/pyproject.toml index 23ee2e46..a6667058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.7" +version = "2.0.0-rc.8" description = "" authors = ["AI21 Labs"] readme = "README.md" From 2d0fe725975b5e1c9e817830994cefd91fca0e58 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Sun, 7 Jan 2024 11:40:33 +0200 Subject: [PATCH 21/45] fix: Removed name parameter from chat message (#36) * fix: removed name parameter from chat message * fix: imports in integration tests --- ai21/models/chat_message.py | 4 +--- examples/studio/answer.py | 2 +- examples/studio/chat.py | 12 ++++-------- examples/studio/completion.py | 2 +- examples/studio/embed.py | 2 +- examples/studio/improvements.py | 2 +- examples/studio/paraphrase.py | 2 +- examples/studio/segmentation.py | 2 +- examples/studio/summarize.py | 2 +- examples/studio/summarize_by_segment.py | 2 +- tests/unittests/clients/studio/resources/conftest.py | 3 +-- 11 files changed, 14 insertions(+), 21 deletions(-) diff --git a/ai21/models/chat_message.py b/ai21/models/chat_message.py index ebf719e3..c7536a77 100644 --- a/ai21/models/chat_message.py +++ b/ai21/models/chat_message.py @@ -1,12 +1,10 @@ from dataclasses import dataclass -from typing import Optional -from ai21.models.role_type import RoleType from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.models.role_type import RoleType @dataclass class ChatMessage(AI21BaseModelMixin): role: RoleType text: str - name: Optional[str] = None diff --git a/examples/studio/answer.py b/examples/studio/answer.py index 4ba86649..10659ddb 100644 --- a/examples/studio/answer.py +++ b/examples/studio/answer.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import Mode, AnswerLength +from ai21.models import Mode, AnswerLength client = AI21Client() diff --git a/examples/studio/chat.py b/examples/studio/chat.py index da8efa36..516093d9 100644 --- a/examples/studio/chat.py +++ b/examples/studio/chat.py @@ -1,15 +1,11 @@ from ai21 import AI21Client -from ai21.resources import Message, RoleType, Penalty +from ai21.models import ChatMessage, RoleType, Penalty system = "You're a support engineer in a SaaS company" messages = [ - Message(text="Hello, I need help with a signup process.", role=RoleType.USER, name="Alice"), - Message( - text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT, name="Bob" - ), - Message( - text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER, name="Alice" - ), + ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), + ChatMessage(text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), + ChatMessage(text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), ] client = AI21Client() diff --git a/examples/studio/completion.py b/examples/studio/completion.py index 333f31a7..b1d0715d 100644 --- a/examples/studio/completion.py +++ b/examples/studio/completion.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import Penalty +from ai21.models import Penalty prompt = ( "The following is a conversation between a user of an eCommerce store and a user operation" diff --git a/examples/studio/embed.py b/examples/studio/embed.py index 57166d21..f0dc5a17 100644 --- a/examples/studio/embed.py +++ b/examples/studio/embed.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import EmbedType +from ai21.models import EmbedType client = AI21Client() response = client.embed.create( diff --git a/examples/studio/improvements.py b/examples/studio/improvements.py index d1919591..f75dea58 100644 --- a/examples/studio/improvements.py +++ b/examples/studio/improvements.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import ImprovementType +from ai21.models import ImprovementType client = AI21Client() response = client.improvements.create( diff --git a/examples/studio/paraphrase.py b/examples/studio/paraphrase.py index f7f643a6..55c3f5c2 100644 --- a/examples/studio/paraphrase.py +++ b/examples/studio/paraphrase.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import ParaphraseStyleType +from ai21.models import ParaphraseStyleType client = AI21Client() response = client.paraphrase.create( diff --git a/examples/studio/segmentation.py b/examples/studio/segmentation.py index 93d3a7be..bf2207cb 100644 --- a/examples/studio/segmentation.py +++ b/examples/studio/segmentation.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import DocumentType +from ai21.models import DocumentType client = AI21Client() diff --git a/examples/studio/summarize.py b/examples/studio/summarize.py index 73dc9ece..54b9d7af 100644 --- a/examples/studio/summarize.py +++ b/examples/studio/summarize.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import DocumentType, SummaryMethod +from ai21.models import DocumentType, SummaryMethod client = AI21Client() response = client.summarize.create( diff --git a/examples/studio/summarize_by_segment.py b/examples/studio/summarize_by_segment.py index 4ab1d82f..e2d2a31d 100644 --- a/examples/studio/summarize_by_segment.py +++ b/examples/studio/summarize_by_segment.py @@ -1,5 +1,5 @@ from ai21 import AI21Client -from ai21.resources import DocumentType +from ai21.models import DocumentType client = AI21Client() response = client.summarize_by_segment.create( diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index a28e559d..8d038dd5 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -42,11 +42,10 @@ def get_studio_answer(): def get_studio_chat(): _DUMMY_MODEL = "dummy-chat-model" _DUMMY_MESSAGES = [ - ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER, name="Alice"), + ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), ChatMessage( text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT, - name="Bob", ), ] _DUMMY_SYSTEM = "You're a support engineer in a SaaS company" From c36a0e40b8a4903071af2ca199a7f42ca2859e3d Mon Sep 17 00:00:00 2001 From: github-actions Date: Sun, 7 Jan 2024 09:41:21 +0000 Subject: [PATCH 22/45] chore(release): v2.0.0-rc.9 [skip ci] --- CHANGELOG.md | 15 +++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54d920cb..2a13df5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,23 @@ +## v2.0.0-rc.9 (2024-01-07) + +### Fix + +* fix: Removed name parameter from chat message (#36) + +* fix: removed name parameter from chat message + +* fix: imports in integration tests ([`2d0fe72`](https://github.com/AI21Labs/ai21-python/commit/2d0fe725975b5e1c9e817830994cefd91fca0e58)) + + ## v2.0.0-rc.8 (2024-01-03) +### Chore + +* chore(release): v2.0.0-rc.8 [skip ci] ([`b766c77`](https://github.com/AI21Labs/ai21-python/commit/b766c770c0540cd70657302cea7cc95b869992d6)) + ### Fix * fix: env vars to http client in sagemaker (#35) diff --git a/ai21/version.py b/ai21/version.py index 8eed6ff1..181566f1 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.8" +VERSION = "2.0.0-rc.9" diff --git a/pyproject.toml b/pyproject.toml index a6667058..15cee03a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.8" +version = "2.0.0-rc.9" description = "" authors = ["AI21 Labs"] readme = "README.md" From bac77e9313a8b70ae42c1ec5479a3856ee7f3d91 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Thu, 18 Jan 2024 22:22:26 +0200 Subject: [PATCH 23/45] fix: chat parameters (#39) * fix: parameters for chat create * fix: imports --- ai21/clients/studio/resources/studio_chat.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 710ed308..4cbbcf8e 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,8 +1,9 @@ -from typing import List, Any, Optional, Dict +from typing import List, Optional from ai21.clients.common.chat_base import Chat from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models.chat_message import ChatMessage +from ai21.models.penalty import Penalty from ai21.models.responses.chat_response import ChatResponse @@ -20,9 +21,9 @@ def create( top_p: Optional[float] = 1.0, top_k_returns: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, - frequency_penalty: Optional[Dict[str, Any]] = None, - presence_penalty: Optional[Dict[str, Any]] = None, - count_penalty: Optional[Dict[str, Any]] = None, + frequency_penalty: Optional[Penalty] = None, + presence_penalty: Optional[Penalty] = None, + count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: body = self._create_body( From 11dd4daca317e858765cabf7c6056f29ff03b261 Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 18 Jan 2024 20:23:20 +0000 Subject: [PATCH 24/45] chore(release): v2.0.0-rc.10 [skip ci] --- CHANGELOG.md | 15 +++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a13df5a..3fc65259 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,23 @@ +## v2.0.0-rc.10 (2024-01-18) + +### Fix + +* fix: chat parameters (#39) + +* fix: parameters for chat create + +* fix: imports ([`bac77e9`](https://github.com/AI21Labs/ai21-python/commit/bac77e9313a8b70ae42c1ec5479a3856ee7f3d91)) + + ## v2.0.0-rc.9 (2024-01-07) +### Chore + +* chore(release): v2.0.0-rc.9 [skip ci] ([`c36a0e4`](https://github.com/AI21Labs/ai21-python/commit/c36a0e40b8a4903071af2ca199a7f42ca2859e3d)) + ### Fix * fix: Removed name parameter from chat message (#36) diff --git a/ai21/version.py b/ai21/version.py index 181566f1..c36ea4e0 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.9" +VERSION = "2.0.0-rc.10" diff --git a/pyproject.toml b/pyproject.toml index 15cee03a..57e55d74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.9" +version = "2.0.0-rc.10" description = "" authors = ["AI21 Labs"] readme = "README.md" From fe3765c800f82b9bd350a851c875c30c45aa41a8 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Mon, 22 Jan 2024 22:50:18 +0200 Subject: [PATCH 25/45] fix: top_k_returns to top_k_return (#40) --- ai21/clients/common/chat_base.py | 6 +++--- ai21/clients/studio/resources/studio_chat.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index 869ddfaa..e73dba3c 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -21,7 +21,7 @@ def create( max_tokens: Optional[int] = 300, min_tokens: Optional[int] = 0, top_p: Optional[float] = 1.0, - top_k_returns: Optional[int] = 0, + top_k_return: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, frequency_penalty: Optional[Penalty] = None, presence_penalty: Optional[Penalty] = None, @@ -43,7 +43,7 @@ def _create_body( max_tokens: Optional[int] = 300, min_tokens: Optional[int] = 0, top_p: Optional[float] = 1.0, - top_k_returns: Optional[int] = 0, + top_k_return: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, frequency_penalty: Optional[Penalty] = None, presence_penalty: Optional[Penalty] = None, @@ -58,7 +58,7 @@ def _create_body( "minTokens": min_tokens, "numResults": num_results, "topP": top_p, - "topKReturn": top_k_returns, + "topKReturn": top_k_return, "stopSequences": stop_sequences, "frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(), "presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(), diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 4cbbcf8e..76008ba9 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -19,7 +19,7 @@ def create( max_tokens: Optional[int] = 300, min_tokens: Optional[int] = 0, top_p: Optional[float] = 1.0, - top_k_returns: Optional[int] = 0, + top_k_return: Optional[int] = 0, stop_sequences: Optional[List[str]] = None, frequency_penalty: Optional[Penalty] = None, presence_penalty: Optional[Penalty] = None, @@ -35,7 +35,7 @@ def create( max_tokens=max_tokens, min_tokens=min_tokens, top_p=top_p, - top_k_returns=top_k_returns, + top_k_return=top_k_return, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, From ec1f9774924f0f44d7449753d364661f5898ed5e Mon Sep 17 00:00:00 2001 From: github-actions Date: Mon, 22 Jan 2024 20:51:05 +0000 Subject: [PATCH 26/45] chore(release): v2.0.0-rc.11 [skip ci] --- CHANGELOG.md | 11 +++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fc65259..5f4a0aea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,19 @@ +## v2.0.0-rc.11 (2024-01-22) + +### Fix + +* fix: top_k_returns to top_k_return (#40) ([`fe3765c`](https://github.com/AI21Labs/ai21-python/commit/fe3765c800f82b9bd350a851c875c30c45aa41a8)) + + ## v2.0.0-rc.10 (2024-01-18) +### Chore + +* chore(release): v2.0.0-rc.10 [skip ci] ([`11dd4da`](https://github.com/AI21Labs/ai21-python/commit/11dd4daca317e858765cabf7c6056f29ff03b261)) + ### Fix * fix: chat parameters (#39) diff --git a/ai21/version.py b/ai21/version.py index c36ea4e0..e5e933e1 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.10" +VERSION = "2.0.0-rc.11" diff --git a/pyproject.toml b/pyproject.toml index 57e55d74..4fd87717 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.10" +version = "2.0.0-rc.11" description = "" authors = ["AI21 Labs"] readme = "README.md" From d7208ab1b3750f87fcd435fd5018c4866e51f748 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Mon, 29 Jan 2024 15:09:29 +0200 Subject: [PATCH 27/45] fix: added user agent with more details (#42) * test: added user agent with more details * test: changed to capital * test: Removed class into single tests --- ai21/ai21_http_client.py | 5 +- tests/unittests/test_ai21_http_client.py | 200 ++++++++++++----------- 2 files changed, 107 insertions(+), 98 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index d5324f94..ca7ea5d5 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -1,3 +1,4 @@ +import platform from typing import Optional, Dict, Any, BinaryIO from ai21.errors import MissingApiKeyError @@ -60,7 +61,9 @@ def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str return http_client def _build_user_agent(self) -> str: - user_agent = f"ai21 studio SDK {VERSION}" + user_agent = ( + f"AI21 studio SDK {VERSION} Python {platform.python_version()} Operating System {platform.platform()}" + ) if self._via is not None: user_agent = f"{user_agent} via: {self._via}" diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py index 1710cc19..86cde2ee 100644 --- a/tests/unittests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -1,3 +1,4 @@ +import platform from typing import Optional import pytest @@ -7,16 +8,20 @@ from ai21.http_client import HttpClient from ai21.version import VERSION +_EXPECTED_USER_AGENT = ( + f"AI21 studio SDK {VERSION} Python {platform.python_version()} Operating System {platform.platform()}" +) + _DUMMY_API_KEY = "dummy_key" _EXPECTED_GET_HEADERS = { "Authorization": "Bearer dummy_key", "Content-Type": "application/json", - "User-Agent": f"ai21 studio SDK {VERSION}", + "User-Agent": _EXPECTED_USER_AGENT, } _EXPECTED_POST_FILE_HEADERS = { "Authorization": "Bearer dummy_key", - "User-Agent": f"ai21 studio SDK {VERSION}", + "User-Agent": _EXPECTED_USER_AGENT, } @@ -29,99 +34,100 @@ def json(self): return self.json_data -class TestAI21StudioClient: - @pytest.mark.parametrize( - ids=[ - "when_pass_only_via__should_include_via_in_user_agent", - ], - argnames=["via", "expected_user_agent"], - argvalues=[ - ("langchain", f"ai21 studio SDK {VERSION} via: langchain"), - ], - ) - def test__build_headers__user_agent(self, via: Optional[str], expected_user_agent: str): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via) - assert client._http_client._headers["User-Agent"] == expected_user_agent - - def test__build_headers__authorization(self): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY) - assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" - - def test__build_headers__when_pass_headers__should_append(self): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) - assert client._http_client._headers["foo"] == "bar" - assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" - - @pytest.mark.parametrize( - ids=[ - "when_api_host_is_set__should_return_set_value", - ], - argnames=["api_host", "expected_api_host"], - argvalues=[ - ("http://test_host", "http://test_host/studio/v1"), - ], - ) - def test__get_base_url(self, api_host: Optional[str], expected_api_host: str): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") - assert client.get_base_url() == expected_api_host - - @pytest.mark.parametrize( - ids=[ - "when_making_request__should_send_appropriate_parameters", - "when_making_request_with_files__should_send_appropriate_post_request", - ], - argnames=["params", "headers"], - argvalues=[ - ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), - ( - {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, - _EXPECTED_POST_FILE_HEADERS, - ), - ], - ) - def test__execute_http_request__( - self, - params, - headers, - dummy_api_host: str, - mock_requests_session: requests.Session, - ): - response_json = {"test_key": "test_value"} - mock_requests_session.request.return_value = MockResponse(response_json, 200) - - http_client = HttpClient(session=mock_requests_session) - client = AI21HTTPClient( - http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" - ) - - response = client.execute_http_request(**params) - assert response == response_json - - if "files" in params: - # We split it because when calling requests with "files", "params" is turned into "data" - mock_requests_session.request.assert_called_once_with( - timeout=300, - headers=headers, - files=params["files"], - data=params["params"], - url=params["url"], - method=params["method"], - ) - else: - mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) - - def test__execute_http_request__when_files_with_put_method__should_raise_value_error( - self, - dummy_api_host: str, - mock_requests_session: requests.Session, - ): - response_json = {"test_key": "test_value"} - http_client = HttpClient(session=mock_requests_session) - client = AI21HTTPClient( - http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1" +@pytest.mark.parametrize( + ids=[ + "when_pass_only_via__should_include_via_in_user_agent", + ], + argnames=["via", "expected_user_agent"], + argvalues=[ + ( + "langchain", + f"{_EXPECTED_USER_AGENT} via: langchain", + ), + ], +) +def test__build_headers__user_agent(via: Optional[str], expected_user_agent: str): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via) + assert client._http_client._headers["User-Agent"] == expected_user_agent + + +def test__build_headers__authorization(): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY) + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + +def test__build_headers__when_pass_headers__should_append(): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, headers={"foo": "bar"}) + assert client._http_client._headers["foo"] == "bar" + assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" + + +@pytest.mark.parametrize( + ids=[ + "when_api_host_is_set__should_return_set_value", + ], + argnames=["api_host", "expected_api_host"], + argvalues=[ + ("http://test_host", "http://test_host/studio/v1"), + ], +) +def test__get_base_url(api_host: Optional[str], expected_api_host: str): + client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") + assert client.get_base_url() == expected_api_host + + +@pytest.mark.parametrize( + ids=[ + "when_making_request__should_send_appropriate_parameters", + "when_making_request_with_files__should_send_appropriate_post_request", + ], + argnames=["params", "headers"], + argvalues=[ + ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), + ( + {"method": "POST", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}}, + _EXPECTED_POST_FILE_HEADERS, + ), + ], +) +def test__execute_http_request__( + params, + headers, + dummy_api_host: str, + mock_requests_session: requests.Session, +): + response_json = {"test_key": "test_value"} + mock_requests_session.request.return_value = MockResponse(response_json, 200) + + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + + response = client.execute_http_request(**params) + assert response == response_json + + if "files" in params: + # We split it because when calling requests with "files", "params" is turned into "data" + mock_requests_session.request.assert_called_once_with( + timeout=300, + headers=headers, + files=params["files"], + data=params["params"], + url=params["url"], + method=params["method"], ) - - mock_requests_session.request.return_value = MockResponse(response_json, 200) - with pytest.raises(ValueError): - params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} - client.execute_http_request(**params) + else: + mock_requests_session.request.assert_called_once_with(timeout=300, headers=headers, **params) + + +def test__execute_http_request__when_files_with_put_method__should_raise_value_error( + dummy_api_host: str, + mock_requests_session: requests.Session, +): + response_json = {"test_key": "test_value"} + http_client = HttpClient(session=mock_requests_session) + client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + + mock_requests_session.request.return_value = MockResponse(response_json, 200) + with pytest.raises(ValueError): + params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} + client.execute_http_request(**params) From 0eacdbb07e62ed415e6c8cfe0cd02b2296c63eb4 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:54:53 +0200 Subject: [PATCH 28/45] ci: Add rc branch prefix trigger for integration tests (#43) * ci: rc branch trigger for integration test * fix: wrapped in quotes --- .github/workflows/integration-tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 83291bbc..ee23e90d 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -4,6 +4,7 @@ on: push: branches: - main + - "rc_*" env: POETRY_VERSION: "1.7.1" From 127cef460d295b7d7a3af83848f30403cf04ba4b Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Tue, 30 Jan 2024 09:57:09 +0200 Subject: [PATCH 29/45] fix: aws tests (#44) * ci: rc branch trigger for integration test * fix: wrapped in quotes * fix: AWS tests * test: ci * fix: AWS tests * test: ci * fix: Removed testing pattern for tests --- .github/workflows/integration-tests.yaml | 2 +- ai21/clients/bedrock/resources/bedrock_completion.py | 2 +- ai21/clients/sagemaker/resources/sagemaker_completion.py | 4 ++-- examples/bedrock/completion.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index ee23e90d..6b12cbde 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -34,7 +34,7 @@ jobs: poetry env use ${{ matrix.python-version }} - name: Install dependencies run: | - poetry install --no-root --only dev --all-extras + poetry install --only dev --all-extras - name: Lint Python (Black) run: | poetry run inv formatter diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index f9ff4646..3f1418e3 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,7 +1,7 @@ from typing import Optional, List -from ai21.resources import Penalty from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource +from ai21.models import Penalty from ai21.models.responses.completion_response import CompletionsResponse diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 373fdcb8..7aeddfaa 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,8 +1,8 @@ from typing import Optional, List -from ai21.resources import Penalty -from ai21.models.responses.completion_response import CompletionsResponse from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.models import Penalty +from ai21.models.responses.completion_response import CompletionsResponse class SageMakerCompletion(SageMakerResource): diff --git a/examples/bedrock/completion.py b/examples/bedrock/completion.py index 8002e258..def48813 100644 --- a/examples/bedrock/completion.py +++ b/examples/bedrock/completion.py @@ -1,5 +1,5 @@ from ai21 import AI21BedrockClient, BedrockModelID -from ai21.resources import Penalty +from ai21.models import Penalty # Bedrock is currently supported only in us-east-1 region. # Either set your profile's region to us-east-1 or uncomment next line From 78709a7ab195fe5cf21b03d646e6e7afb6775289 Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Tue, 30 Jan 2024 10:50:14 +0200 Subject: [PATCH 30/45] fix: Integration tests (#41) * fix: types * test: Added some integration tests * test: improvements * test: test_paraphrase.py * fix: doc * fix: removed unused comment * test: test_summarize.py * test: Added tests for test_summarize_by_segment.py * test: test_segmentation.py * fix: file id in library response * fix: example for library * ci: Add rc branch prefix trigger for integration tests (#43) * ci: rc branch trigger for integration test * fix: wrapped in quotes * fix: types * test: Added some integration tests * test: improvements * test: test_paraphrase.py * fix: doc * fix: removed unused comment * test: test_summarize.py * test: Added tests for test_summarize_by_segment.py * test: test_segmentation.py * fix: file id in library response * fix: example for library * docs: docstrings * fix: question * fix: CR * test: Added tests to segment type in embed --- ai21/clients/common/answer_base.py | 11 +- ai21/clients/common/chat_base.py | 19 +++ ai21/clients/common/completion_base.py | 18 +++ ai21/clients/common/custom_model_base.py | 10 ++ ai21/clients/common/dataset_base.py | 11 ++ ai21/clients/common/embed_base.py | 8 ++ ai21/clients/common/gec_base.py | 6 + ai21/clients/common/improvements_base.py | 7 ++ ai21/clients/common/paraphrase_base.py | 10 ++ ai21/clients/common/segmentation_base.py | 7 ++ ai21/clients/common/summarize_base.py | 8 ++ .../common/summarize_by_segment_base.py | 8 ++ ai21/clients/studio/resources/studio_embed.py | 3 +- .../studio/resources/studio_summarize.py | 1 - .../responses/library_answer_response.py | 2 +- examples/studio/library.py | 8 +- examples/studio/library_answer.py | 2 +- .../clients/studio/__init__.py | 0 .../clients/studio/test_answer.py | 38 ++++++ .../clients/studio/test_chat.py | 94 +++++++++++++++ .../clients/studio/test_completion.py | 112 ++++++++++++++++++ .../clients/studio/test_embed.py | 37 ++++++ .../clients/studio/test_gec.py | 31 +++++ .../clients/studio/test_improvements.py | 13 ++ .../clients/studio/test_paraphrase.py | 51 ++++++++ .../clients/studio/test_segmentation.py | 55 +++++++++ .../clients/studio/test_summarize.py | 62 ++++++++++ .../studio/test_summarize_by_segment.py | 66 +++++++++++ 28 files changed, 692 insertions(+), 6 deletions(-) create mode 100644 tests/integration_tests/clients/studio/__init__.py create mode 100644 tests/integration_tests/clients/studio/test_answer.py create mode 100644 tests/integration_tests/clients/studio/test_chat.py create mode 100644 tests/integration_tests/clients/studio/test_completion.py create mode 100644 tests/integration_tests/clients/studio/test_embed.py create mode 100644 tests/integration_tests/clients/studio/test_gec.py create mode 100644 tests/integration_tests/clients/studio/test_improvements.py create mode 100644 tests/integration_tests/clients/studio/test_paraphrase.py create mode 100644 tests/integration_tests/clients/studio/test_segmentation.py create mode 100644 tests/integration_tests/clients/studio/test_summarize.py create mode 100644 tests/integration_tests/clients/studio/test_summarize_by_segment.py diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py index a1543646..1821db9a 100644 --- a/ai21/clients/common/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -17,6 +17,15 @@ def create( mode: Optional[Mode] = None, **kwargs, ) -> AnswerResponse: + """ + + :param context: A string containing the document context for which the question will be answered + :param question: A string containing the question to be answered based on the provided context. + :param answer_length: Approximate length of the answer in words. + :param mode: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse: @@ -26,7 +35,7 @@ def _create_body( self, context: str, question: str, - answer_length: Optional[str], + answer_length: Optional[AnswerLength], mode: Optional[str], ) -> Dict[str, Any]: return {"context": context, "question": question, "answerLength": answer_length, "mode": mode} diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index e73dba3c..03037491 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -28,6 +28,25 @@ def create( count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: + """ + + :param model: model type you wish to interact with + :param messages: A sequence of messages ingested by the model, which then returns the assistant's response + :param system: Offers the model overarching guidance on its response approach, encapsulating context, tone, + guardrails, and more + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index a2bc8c3d..06fef7d8 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -28,6 +28,24 @@ def create( epoch: Optional[int] = None, **kwargs, ) -> CompletionsResponse: + """ + :param model: model type you wish to interact with + :param prompt: Text for model to complete + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param custom_model: + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param epoch: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse: diff --git a/ai21/clients/common/custom_model_base.py b/ai21/clients/common/custom_model_base.py index 7d3b55ae..303b1adf 100644 --- a/ai21/clients/common/custom_model_base.py +++ b/ai21/clients/common/custom_model_base.py @@ -18,6 +18,16 @@ def create( num_epochs: Optional[int] = None, **kwargs, ) -> None: + """ + + :param dataset_id: The dataset you want to train your model on. + :param model_name: The name of your trained model + :param model_type: The type of model to train. + :param learning_rate: The learning rate used for training. + :param num_epochs: Number of epochs for training + :param kwargs: + :return: + """ pass @abstractmethod diff --git a/ai21/clients/common/dataset_base.py b/ai21/clients/common/dataset_base.py index 9fa57f85..732dee39 100644 --- a/ai21/clients/common/dataset_base.py +++ b/ai21/clients/common/dataset_base.py @@ -19,6 +19,17 @@ def create( split_ratio: Optional[float] = None, **kwargs, ): + """ + + :param file_path: Local path to dataset + :param dataset_name: Dataset name. Must be unique + :param selected_columns: Mapping of the columns in the dataset file to prompt and completion columns. + :param approve_whitespace_correction: Automatically correct examples that violate best practices + :param delete_long_rows: Allow removal of examples where prompt + completion lengths exceeds 2047 tokens + :param split_ratio: + :param kwargs: + :return: + """ pass @abstractmethod diff --git a/ai21/clients/common/embed_base.py b/ai21/clients/common/embed_base.py index baadd4ec..aaf9363e 100644 --- a/ai21/clients/common/embed_base.py +++ b/ai21/clients/common/embed_base.py @@ -10,6 +10,14 @@ class Embed(ABC): @abstractmethod def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: + """ + + :param texts: A list of strings, each representing a document or segment of text to be embedded. + :param type: For retrieval/search use cases, indicates whether the texts that were + sent are segments or the query. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse: diff --git a/ai21/clients/common/gec_base.py b/ai21/clients/common/gec_base.py index 8de743e2..091e6427 100644 --- a/ai21/clients/common/gec_base.py +++ b/ai21/clients/common/gec_base.py @@ -9,6 +9,12 @@ class GEC(ABC): @abstractmethod def create(self, text: str, **kwargs) -> GECResponse: + """ + + :param text: The input text to be corrected. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> GECResponse: diff --git a/ai21/clients/common/improvements_base.py b/ai21/clients/common/improvements_base.py index df912e1d..df13fe58 100644 --- a/ai21/clients/common/improvements_base.py +++ b/ai21/clients/common/improvements_base.py @@ -10,6 +10,13 @@ class Improvements(ABC): @abstractmethod def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: + """ + + :param text: The input text to be improved. + :param types: Types of improvements to apply. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse: diff --git a/ai21/clients/common/paraphrase_base.py b/ai21/clients/common/paraphrase_base.py index 917cdd75..3c01bbb7 100644 --- a/ai21/clients/common/paraphrase_base.py +++ b/ai21/clients/common/paraphrase_base.py @@ -18,6 +18,16 @@ def create( end_index: Optional[int] = None, **kwargs, ) -> ParaphraseResponse: + """ + + :param text: The input text to be paraphrased. + :param style: Controls length and tone + :param start_index: Specifies the starting position of the paraphrasing process in the given text + :param end_index: specifies the position of the last character to be paraphrased, including the character + following it. If the parameter is not provided, the default value is set to the length of the given text. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> ParaphraseResponse: diff --git a/ai21/clients/common/segmentation_base.py b/ai21/clients/common/segmentation_base.py index c4f658a9..97c74104 100644 --- a/ai21/clients/common/segmentation_base.py +++ b/ai21/clients/common/segmentation_base.py @@ -10,6 +10,13 @@ class Segmentation(ABC): @abstractmethod def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: + """ + + :param source: Raw input text, or URL of a web page. + :param source_type: The type of the source - either TEXT or URL. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse: diff --git a/ai21/clients/common/summarize_base.py b/ai21/clients/common/summarize_base.py index 85cec2f5..2a70dfcd 100644 --- a/ai21/clients/common/summarize_base.py +++ b/ai21/clients/common/summarize_base.py @@ -16,6 +16,14 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: + """ + :param source: The input text, or URL of a web page to be summarized. + :param source_type: Either TEXT or URL + :param focus: Summaries focused on a topic of your choice. + :param summary_method: + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse: diff --git a/ai21/clients/common/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py index 40a4abfa..236337de 100644 --- a/ai21/clients/common/summarize_by_segment_base.py +++ b/ai21/clients/common/summarize_by_segment_base.py @@ -19,6 +19,14 @@ def create( focus: Optional[str] = None, **kwargs, ) -> SummarizeBySegmentResponse: + """ + + :param source: The input text, or URL of a web page to be summarized. + :param source_type: Either TEXT or URL + :param focus: Summaries focused on a topic of your choice. + :param kwargs: + :return: + """ pass def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse: diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index 7495ea67..7e7c8fad 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -2,11 +2,12 @@ from ai21.clients.common.embed_base import Embed from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.embed_type import EmbedType from ai21.models.responses.embed_response import EmbedResponse class StudioEmbed(StudioResource, Embed): - def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse: + def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(texts=texts, type=type) response = self._post(url=url, body=body) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index b2b5f860..4180ff52 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -18,7 +18,6 @@ def create( summary_method: Optional[SummaryMethod] = None, **kwargs, ) -> SummarizeResponse: - # Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object. body = self._create_body( source=source, source_type=source_type, diff --git a/ai21/models/responses/library_answer_response.py b/ai21/models/responses/library_answer_response.py index 36341eda..28fab165 100644 --- a/ai21/models/responses/library_answer_response.py +++ b/ai21/models/responses/library_answer_response.py @@ -6,7 +6,7 @@ @dataclass class SourceDocument(AI21BaseModelMixin): - field_id: str + file_id: str name: str highlights: List[str] public_url: Optional[str] = None diff --git a/examples/studio/library.py b/examples/studio/library.py index d693d697..ca8e4840 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -22,7 +22,12 @@ def validate_file_deleted(): file_path = os.getcwd() path = os.path.join(file_path, file_name) -file_utils.create_file(file_path, file_name, content="test content" * 100) +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" +file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT) file_id = client.library.files.create( file_path=path, @@ -31,6 +36,7 @@ def validate_file_deleted(): public_url="www.example.com", ) print(file_id) + files = client.library.files.list() print(files) uploaded_file = client.library.files.get(file_id) diff --git a/examples/studio/library_answer.py b/examples/studio/library_answer.py index 54b2bb1d..20d46402 100644 --- a/examples/studio/library_answer.py +++ b/examples/studio/library_answer.py @@ -2,5 +2,5 @@ client = AI21Client() -response = client.library.answer.create(question="Where is Thailand?") +response = client.library.answer.create(question="Can you tell me something about Holland?") print(response) diff --git a/tests/integration_tests/clients/studio/__init__.py b/tests/integration_tests/clients/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py new file mode 100644 index 00000000..51dd4fa2 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_answer.py @@ -0,0 +1,38 @@ +import pytest +from ai21 import AI21Client +from ai21.models import AnswerLength, Mode + +_CONTEXT = ( + "Holland is a geographical region[2] and former province on the western coast of" + " the Netherlands. From the " + "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " + "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " + "economic power, dominating the other provinces of the newly independent Dutch Republic." +) + + +@pytest.mark.parametrize( + ids=[ + "when_answer_is_in_context", + "when_answer_not_in_context", + ], + argnames=["question", "is_answer_in_context", "expected_answer_type"], + argvalues=[ + ("When did Holland become an economic power?", True, str), + ("Is the ocean blue?", False, None), + ], +) +def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: type): + client = AI21Client() + response = client.answer.create( + context=_CONTEXT, + question=question, + answer_length=AnswerLength.LONG, + mode=Mode.FLEXIBLE, + ) + + assert response.answer_in_context == is_answer_in_context + if is_answer_in_context: + assert isinstance(response.answer, str) + else: + assert response.answer is None diff --git a/tests/integration_tests/clients/studio/test_chat.py b/tests/integration_tests/clients/studio/test_chat.py new file mode 100644 index 00000000..70d26761 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_chat.py @@ -0,0 +1,94 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import ChatMessage, RoleType, Penalty, FinishReason + +_MODEL = "j2-ultra" +_MESSAGES = [ + ChatMessage( + text="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", + role=RoleType.USER, + ), +] +_SYSTEM = "You are a teacher in a public school" + + +def test_chat(): + num_results = 5 + messages = _MESSAGES + + client = AI21Client() + response = client.chat.create( + system=_SYSTEM, + messages=messages, + num_results=num_results, + max_tokens=64, + temperature=0.7, + min_tokens=1, + stop_sequences=["\n"], + top_p=0.3, + top_k_return=0, + model=_MODEL, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.outputs[0].role == RoleType.ASSISTANT + assert isinstance(response.outputs[0].text, str) + assert response.outputs[0].finish_reason == FinishReason(reason="endoftext") + + assert len(response.outputs) == num_results + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (2, "##", "length"), + (100, "##", "endoftext"), + (20, ".", "stop"), + ], +) +def test_chat_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.chat.create( + messages=_MESSAGES, + system=_SYSTEM, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + ) + + assert response.outputs[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py new file mode 100644 index 00000000..9938fa2a --- /dev/null +++ b/tests/integration_tests/clients/studio/test_completion.py @@ -0,0 +1,112 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import Penalty + +_PROMPT = """ +User: Haven't received a confirmation email for my order #12345. +Assistant: I'm sorry to hear that. I'll look into it right away. +User: Can you please let me know when I can expect to receive it? +""" + + +def test_completion(): + num_results = 3 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=0.7, + top_p=0.2, + top_k_return=0.2, + stop_sequences=["##"], + num_results=num_results, + custom_model=None, + epoch=1, + count_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + frequency_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + presence_penalty=Penalty( + scale=0, + apply_to_emojis=False, + apply_to_numbers=False, + apply_to_stopwords=False, + apply_to_punctuation=False, + apply_to_whitespaces=False, + ), + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Check the results aren't all the same + assert len([completion.data.text for completion in response.completions]) == num_results + for completion in response.completions: + assert isinstance(completion.data.text, str) + + +def test_completion_when_temperature_1_and_top_p_is_0__should_return_same_response(): + num_results = 5 + + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=64, + model="j2-ultra", + temperature=1, + top_p=0, + top_k_return=0, + num_results=num_results, + epoch=1, + ) + + assert response.prompt.text == _PROMPT + assert len(response.completions) == num_results + # Verify all results are the same + assert len(set([completion.data.text for completion in response.completions])) == 1 + + +@pytest.mark.parametrize( + ids=[ + "finish_reason_length", + "finish_reason_endoftext", + "finish_reason_stop_sequence", + ], + argnames=["max_tokens", "stop_sequences", "reason"], + argvalues=[ + (10, "##", "length"), + (100, "##", "endoftext"), + (50, "\n", "stop"), + ], +) +def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( + max_tokens: int, stop_sequences: str, reason: str +): + client = AI21Client() + response = client.completion.create( + prompt=_PROMPT, + max_tokens=max_tokens, + model="j2-ultra", + temperature=1, + top_p=0, + num_results=1, + stop_sequences=[stop_sequences], + top_k_return=0, + epoch=1, + ) + + assert response.completions[0].finish_reason.reason == reason diff --git a/tests/integration_tests/clients/studio/test_embed.py b/tests/integration_tests/clients/studio/test_embed.py new file mode 100644 index 00000000..8fc77dc5 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_embed.py @@ -0,0 +1,37 @@ +from typing import List + +import pytest +from ai21 import AI21Client +from ai21.models import EmbedType + +_TEXT_0 = "Holland is a geographical region and former province on the western coast of the Netherlands." +_TEXT_1 = "Germany is a country in Central Europe. It is the second-most populous country in Europe after Russia" + +_SEGMENT_0 = "The sun sets behind the mountains," +_SEGMENT_1 = "casting a warm glow over" +_SEGMENT_2 = "the city of Amsterdam." + + +@pytest.mark.parametrize( + ids=[ + "when_single_text_and_query__should_return_single_embedding", + "when_multiple_text_and_query__should_return_multiple_embeddings", + "when_single_text_and_segment__should_return_single_embedding", + "when_multiple_text_and_segment__should_return_multiple_embeddings", + ], + argnames=["texts", "type"], + argvalues=[ + ([_TEXT_0], EmbedType.QUERY), + ([_TEXT_0, _TEXT_1], EmbedType.QUERY), + ([_SEGMENT_0], EmbedType.SEGMENT), + ([_SEGMENT_0, _SEGMENT_1, _SEGMENT_2], EmbedType.SEGMENT), + ], +) +def test_embed(texts: List[str], type: EmbedType): + client = AI21Client() + response = client.embed.create( + texts=texts, + type=type, + ) + + assert len(response.results) == len(texts) diff --git a/tests/integration_tests/clients/studio/test_gec.py b/tests/integration_tests/clients/studio/test_gec.py new file mode 100644 index 00000000..51418cf1 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_gec.py @@ -0,0 +1,31 @@ +import pytest +from ai21 import AI21Client +from ai21.models import CorrectionType + + +@pytest.mark.parametrize( + ids=[ + "should_fix_spelling", + "should_fix_grammar", + "should_fix_missing_word", + "should_fix_punctuation", + "should_fix_wrong_word", + ], + argnames=["text", "correction_type", "expected_suggestion"], + argvalues=[ + ("jazzz is music", CorrectionType.SPELLING, "Jazz"), + ("You am nice", CorrectionType.GRAMMAR, "are"), + ( + "He stared out the window, lost in thought, as the raindrops against the glass.", + CorrectionType.MISSING_WORD, + "raindrops fell against", + ), + ("He is a well known author.", CorrectionType.PUNCTUATION, "well-known"), + ("He is a dog-known author.", CorrectionType.WRONG_WORD, "well-known"), + ], +) +def test_gec(text: str, correction_type: CorrectionType, expected_suggestion: str): + client = AI21Client() + response = client.gec.create(text=text) + assert response.corrections[0].suggestion == expected_suggestion + assert response.corrections[0].correction_type == correction_type diff --git a/tests/integration_tests/clients/studio/test_improvements.py b/tests/integration_tests/clients/studio/test_improvements.py new file mode 100644 index 00000000..3488f781 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_improvements.py @@ -0,0 +1,13 @@ +from ai21 import AI21Client +from ai21.models import ImprovementType + + +def test_improvements(): + client = AI21Client() + response = client.improvements.create( + text="Affiliated with the profession of project management," + " I have ameliorated myself with a different set of hard skills as well as soft skills.", + types=[ImprovementType.FLUENCY], + ) + + assert len(response.improvements) > 0 diff --git a/tests/integration_tests/clients/studio/test_paraphrase.py b/tests/integration_tests/clients/studio/test_paraphrase.py new file mode 100644 index 00000000..a7ba93aa --- /dev/null +++ b/tests/integration_tests/clients/studio/test_paraphrase.py @@ -0,0 +1,51 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import ParaphraseStyleType + + +def test_paraphrase(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.FORMAL, + start_index=0, + end_index=20, + ) + for suggestion in response.suggestions: + print(suggestion.text) + assert len(response.suggestions) > 0 + + +def test_paraphrase__when_start_and_end_index_is_small__should_not_return_suggestions(): + client = AI21Client() + response = client.paraphrase.create( + text="The cat (Felis catus) is a domestic species of small carnivorous mammal", + style=ParaphraseStyleType.GENERAL, + start_index=0, + end_index=5, + ) + assert len(response.suggestions) == 0 + + +@pytest.mark.parametrize( + ids=["when_general", "when_casual", "when_long", "when_short", "when_formal"], + argnames=["style"], + argvalues=[ + (ParaphraseStyleType.GENERAL,), + (ParaphraseStyleType.CASUAL,), + (ParaphraseStyleType.LONG,), + (ParaphraseStyleType.SHORT,), + (ParaphraseStyleType.FORMAL,), + ], +) +def test_paraphrase_styles(style: ParaphraseStyleType): + client = AI21Client() + response = client.paraphrase.create( + text="Today is a beautiful day.", + style=style, + start_index=0, + end_index=25, + ) + + assert len(response.suggestions) > 0 diff --git a/tests/integration_tests/clients/studio/test_segmentation.py b/tests/integration_tests/clients/studio/test_segmentation.py new file mode 100644 index 00000000..ede8707c --- /dev/null +++ b/tests/integration_tests/clients/studio/test_segmentation.py @@ -0,0 +1,55 @@ +import pytest +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_segments", + "when_source_is_url__should_return_a_segments", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_segmentation(source: str, source_type: DocumentType): + client = AI21Client() + + response = client.segmentation.create( + source=source, + source_type=source_type, + ) + + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_type is not None + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + # "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + # (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_segmentation__source_and_source_type_misalignment(source: str, source_type: DocumentType): + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.segmentation.create( + source=source, + source_type=source_type, + ) diff --git a/tests/integration_tests/clients/studio/test_summarize.py b/tests/integration_tests/clients/studio/test_summarize.py new file mode 100644 index 00000000..6c7ae4e9 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize.py @@ -0,0 +1,62 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType, SummaryMethod + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the +Netherlands. From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text__should_return_a_suggestion", + "when_source_is_url__should_return_a_suggestion", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.TEXT), + (_SOURCE_URL, DocumentType.URL), + ], +) +def test_summarize(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + response = client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) + assert response.summary is not None + assert focus in response.summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize__source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize.create( + source=source, + source_type=source_type, + summary_method=SummaryMethod.SEGMENTS, + focus=focus, + ) diff --git a/tests/integration_tests/clients/studio/test_summarize_by_segment.py b/tests/integration_tests/clients/studio/test_summarize_by_segment.py new file mode 100644 index 00000000..f39dd308 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_summarize_by_segment.py @@ -0,0 +1,66 @@ +import pytest + +from ai21 import AI21Client +from ai21.errors import UnprocessableEntity +from ai21.models import DocumentType + +_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the Netherlands. + From the 10th to the 16th century, Holland proper was a unified political + region within the Holy Roman Empire as a county ruled by the counts of Holland. + By the 17th century, the province of Holland had risen to become a maritime and economic power, + dominating the other provinces of the newly independent Dutch Republic.""" + +_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" + + +def test_summarize_by_segment__when_text__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_TEXT, + source_type=DocumentType.TEXT, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert response.segments[0].segment_html is None + assert isinstance(response.segments[0].summary, str) + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].segment_type == "normal_text" + assert response.segments[0].has_summary + + +def test_summarize_by_segment__when_url__should_return_response(): + client = AI21Client() + response = client.summarize_by_segment.create( + source=_SOURCE_URL, + source_type=DocumentType.URL, + focus="Holland", + ) + assert isinstance(response.segments[0].segment_text, str) + assert isinstance(response.segments[0].segment_html, str) + assert isinstance(response.segments[0].summary, str) + assert response.segments[0].segment_type == "normal_text" + assert len(response.segments[0].highlights) > 0 + assert response.segments[0].has_summary + + +@pytest.mark.parametrize( + ids=[ + "when_source_is_text_and_source_type_is_url__should_raise_error", + "when_source_is_url_and_source_type_is_text__should_raise_error", + ], + argnames=["source", "source_type"], + argvalues=[ + (_SOURCE_TEXT, DocumentType.URL), + (_SOURCE_URL, DocumentType.TEXT), + ], +) +def test_summarize_by_segment__source_and_source_type_misalignment(source: str, source_type: DocumentType): + focus = "Holland" + + client = AI21Client() + with pytest.raises(UnprocessableEntity): + client.summarize_by_segment.create( + source=source, + source_type=source_type, + focus=focus, + ) From 347e7f98d905027c0b429a5a93ad1fd94cc3b36d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Jan 2024 10:52:46 +0200 Subject: [PATCH 31/45] chore(deps-dev): bump jinja2 from 3.1.2 to 3.1.3 (#38) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.2 to 3.1.3. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.2...3.1.3) --- updated-dependencies: - dependency-name: jinja2 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> --- poetry.lock | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/poetry.lock b/poetry.lock index cd17ee01..bb8810f6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "ai21-tokenizer" @@ -399,13 +399,13 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] [[package]] name = "jinja2" -version = "3.1.2" +version = "3.1.3" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" files = [ - {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, - {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, + {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"}, + {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"}, ] [package.dependencies] @@ -1009,24 +1009,24 @@ python-versions = ">=3.6" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d92f81886165cb14d7b067ef37e142256f1c6a90a65cd156b063a43da1708cfd"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b5edda50e5e9e15e54a6a8a0070302b00c518a9d32accc2346ad6c984aacd279"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:7048c338b6c86627afb27faecf418768acb6331fc24cfa56c93e8c9780f815fa"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, @@ -1034,7 +1034,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3fcc54cb0c8b811ff66082de1680b4b14cf8a81dce0d4fbf665c2265a81e07a1"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, @@ -1042,7 +1042,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:665f58bfd29b167039f714c6998178d27ccd83984084c286110ef26b230f259f"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, @@ -1050,7 +1050,7 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:9eb5dee2772b0f704ca2e45b1713e4e5198c18f515b52743576d196348f374d3"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, From bf3e7409386057e0b6f720594539e0ef5c2823ae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Jan 2024 10:54:42 +0200 Subject: [PATCH 32/45] chore(deps-dev): bump gitpython from 3.1.40 to 3.1.41 (#37) Bumps [gitpython](https://github.com/gitpython-developers/GitPython) from 3.1.40 to 3.1.41. - [Release notes](https://github.com/gitpython-developers/GitPython/releases) - [Changelog](https://github.com/gitpython-developers/GitPython/blob/main/CHANGES) - [Commits](https://github.com/gitpython-developers/GitPython/compare/3.1.40...3.1.41) --- updated-dependencies: - dependency-name: gitpython dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> --- poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index bb8810f6..d1f0c968 100644 --- a/poetry.lock +++ b/poetry.lock @@ -314,20 +314,20 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.40" +version = "3.1.41" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.40-py3-none-any.whl", hash = "sha256:cf14627d5a8049ffbf49915732e5eddbe8134c3bdb9d476e6182b676fc573f8a"}, - {file = "GitPython-3.1.40.tar.gz", hash = "sha256:22b126e9ffb671fdd0c129796343a02bf67bf2994b35449ffc9321aa755e18a4"}, + {file = "GitPython-3.1.41-py3-none-any.whl", hash = "sha256:c36b6634d069b3f719610175020a9aed919421c87552185b085e04fbbdb10b7c"}, + {file = "GitPython-3.1.41.tar.gz", hash = "sha256:ed66e624884f76df22c8e16066d567aaa5a37d5b5fa19db2c6df6f7156db9048"}, ] [package.dependencies] gitdb = ">=4.0.1,<5" [package.extras] -test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"] +test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] [[package]] name = "idna" From ee75fb5ce411c9a14879adefe9800dbf89486f26 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:01:49 +0200 Subject: [PATCH 33/45] chore(deps): bump actions/upload-artifact from 3 to 4 (#21) Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> --- .github/workflows/integration-tests.yaml | 2 +- .github/workflows/test.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 6b12cbde..db93631e 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -75,7 +75,7 @@ jobs: run: | poetry run pytest tests/integration_tests/ - name: Upload pytest integration tests results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: pytest-results-${{ matrix.python-version }} path: junit/test-results-${{ matrix.python-version }}.xml diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3e454388..b66a211b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -71,7 +71,7 @@ jobs: run: | poetry run pytest tests/unittests/ - name: Upload pytest test results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: pytest-results-${{ matrix.python-version }} path: junit/test-results-${{ matrix.python-version }}.xml From 665c9533711670c358336c6674391f8afcbe8bbb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:04:04 +0200 Subject: [PATCH 34/45] chore(deps): bump pypa/gh-action-pypi-publish from 1.4.2 to 1.8.11 (#20) Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.4.2 to 1.8.11. - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/27b31702a0e7fc50959f5ad993c78deac1bdfc29...2f6f737ca5f74c637829c0f5c3acd0e29ea5e8bf) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> --- .github/workflows/publish.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index e27ee41b..3d806b8c 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -34,7 +34,7 @@ jobs: - name: Build package run: poetry build - name: Publish package to PYPI - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + uses: pypa/gh-action-pypi-publish@2f6f737ca5f74c637829c0f5c3acd0e29ea5e8bf with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} From 40d6df85cac9832753a1a5f5c218cc0ad0a6d4f6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:07:24 +0200 Subject: [PATCH 35/45] chore(deps): bump actions/setup-python from 4 to 5 (#19) Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> --- .github/workflows/integration-tests.yaml | 4 ++-- .github/workflows/publish.yaml | 2 +- .github/workflows/test.yaml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index db93631e..4b63cf92 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -24,7 +24,7 @@ jobs: run: | pipx install poetry - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: poetry @@ -56,7 +56,7 @@ jobs: run: | pipx install poetry - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: poetry diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 3d806b8c..607abb6e 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -23,7 +23,7 @@ jobs: run: | pipx install poetry - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: poetry diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b66a211b..d679e310 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,7 +20,7 @@ jobs: run: | pipx install poetry - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: poetry @@ -52,7 +52,7 @@ jobs: run: | pipx install poetry - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: poetry From 03a36a9c9324a157b5aff2ecb11a1d6599ef3dee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:11:32 +0200 Subject: [PATCH 36/45] chore(deps): bump amannn/action-semantic-pull-request (#2) Bumps [amannn/action-semantic-pull-request](https://github.com/amannn/action-semantic-pull-request) from 5.0.2 to 5.4.0. - [Release notes](https://github.com/amannn/action-semantic-pull-request/releases) - [Changelog](https://github.com/amannn/action-semantic-pull-request/blob/main/CHANGELOG.md) - [Commits](https://github.com/amannn/action-semantic-pull-request/compare/v5.0.2...v5.4.0) --- updated-dependencies: - dependency-name: amannn/action-semantic-pull-request dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> --- .github/workflows/semantic-pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 955b7473..2768c7de 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -16,7 +16,7 @@ jobs: timeout-minutes: 1 steps: - name: Semantic pull-request - uses: amannn/action-semantic-pull-request@v5.0.2 + uses: amannn/action-semantic-pull-request@v5.4.0 with: requireScope: false wip: true From e7fff8081036c9c9316ca45389a8f80cc910314a Mon Sep 17 00:00:00 2001 From: asafgardin <147075902+asafgardin@users.noreply.github.com> Date: Tue, 30 Jan 2024 18:04:02 +0200 Subject: [PATCH 37/45] test: Added tests for library (#45) * test: Added tests for library * fix: CR --- .../clients/resources/library_file.txt | 19 ++++++++ .../clients/studio/conftest.py | 45 +++++++++++++++++++ .../clients/studio/test_library.py | 37 +++++++++++++++ .../clients/studio/test_library_answer.py | 16 +++++++ .../clients/studio/test_library_search.py | 11 +++++ 5 files changed, 128 insertions(+) create mode 100644 tests/integration_tests/clients/resources/library_file.txt create mode 100644 tests/integration_tests/clients/studio/conftest.py create mode 100644 tests/integration_tests/clients/studio/test_library.py create mode 100644 tests/integration_tests/clients/studio/test_library_answer.py create mode 100644 tests/integration_tests/clients/studio/test_library_search.py diff --git a/tests/integration_tests/clients/resources/library_file.txt b/tests/integration_tests/clients/resources/library_file.txt new file mode 100644 index 00000000..174df674 --- /dev/null +++ b/tests/integration_tests/clients/resources/library_file.txt @@ -0,0 +1,19 @@ +Albert Einstein was a renowned physicist who made significant contributions to the field of theoretical physics. Born on March 14, 1879, in Ulm, in the Kingdom of Württemberg in the German Empire, Einstein's early life showed signs of his later intellectual prowess. + +Einstein attended the Swiss Federal Institute of Technology in Zurich, where he studied physics and mathematics. Despite facing challenges and financial difficulties, he persevered in his studies and graduated in 1900. After graduation, he struggled to secure a teaching position but eventually found work as a patent examiner at the Swiss Patent Office. + +In 1905, often referred to as his "miracle year," Einstein published four groundbreaking papers that transformed the scientific landscape. These papers covered the photoelectric effect, Brownian motion, special relativity, and the famous equation E=mc², demonstrating the equivalence of mass and energy. + +His theory of special relativity, published in the paper "On the Electrodynamics of Moving Bodies," challenged traditional notions of space and time. It introduced the concept of spacetime and showed that time is relative, depending on the observer's motion. + +In 1915, Einstein presented the general theory of relativity, providing a new understanding of gravitation. According to general relativity, massive objects like planets and stars cause a curvature in spacetime, influencing the motion of other objects. This theory successfully explained phenomena like the bending of light around massive objects. + +Einstein's work laid the foundation for modern cosmology and astrophysics. His predictions, such as the bending of light by gravity, were later confirmed through experiments and observations. + +Apart from his scientific endeavors, Einstein was an advocate for civil rights, pacifism, and Zionism. He spoke out against discrimination and injustice, using his platform to promote social and political causes. In 1933, Einstein fled Nazi Germany and settled in the United States, where he continued his scientific research. + +Einstein received the Nobel Prize in Physics in 1921 for his explanation of the photoelectric effect. Despite his immense contributions to science, he remained humble and often expressed a deep curiosity about the mysteries of the universe. + +In the latter part of his life, Einstein worked towards a unified field theory, attempting to combine electromagnetism and gravity into a single framework. However, this goal remained elusive, and Einstein's efforts in this direction were not as successful as his earlier work. + +Albert Einstein passed away on April 18, 1955, leaving behind a legacy that continues to shape our understanding of the physical world. His intellectual brilliance, coupled with his commitment to social justice, has made him an enduring symbol of scientific achievement and moral responsibility. The impact of Einstein's ideas extends far beyond the realm of physics, influencing fields as diverse as philosophy, literature, and popular culture. diff --git a/tests/integration_tests/clients/studio/conftest.py b/tests/integration_tests/clients/studio/conftest.py new file mode 100644 index 00000000..043f5b98 --- /dev/null +++ b/tests/integration_tests/clients/studio/conftest.py @@ -0,0 +1,45 @@ +import time +from pathlib import Path + +import pytest + +from ai21 import AI21Client + +LIBRARY_FILE_TO_UPLOAD = str(Path(__file__).parent.parent / "resources" / "library_file.txt") +DEFAULT_LABELS = ["einstein", "science"] + + +def _wait_for_file_to_process(client: AI21Client, file_id: str, timeout: float = 20): + start_time = time.time() + + elapsed_time = time.time() - start_time + while elapsed_time < timeout: + file_response = client.library.files.get(file_id) + + if file_response.status == "PROCESSED": + return + + elapsed_time = time.time() - start_time + time.sleep(0.5) + + raise TimeoutError(f"Timeout: {timeout} seconds passed. File processing not completed") + + +def _delete_file(client: AI21Client, file_id: str): + _wait_for_file_to_process(client, file_id) + client.library.files.delete(file_id) + + +@pytest.fixture(scope="module", autouse=True) +def file_in_library(): + """ + Uploads a file to the library and deletes it after the test is done + This happens in a scope of a module so the file is uploaded only once + :return: file_id: str + """ + client = AI21Client() + + file_id = client.library.files.create(file_path=LIBRARY_FILE_TO_UPLOAD, labels=DEFAULT_LABELS) + _wait_for_file_to_process(client, file_id) + yield file_id + _delete_file(client, file_id=file_id) diff --git a/tests/integration_tests/clients/studio/test_library.py b/tests/integration_tests/clients/studio/test_library.py new file mode 100644 index 00000000..7fe021e1 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_library.py @@ -0,0 +1,37 @@ +from pathlib import Path +from time import sleep + +from ai21 import AI21Client +from tests.integration_tests.clients.studio.conftest import LIBRARY_FILE_TO_UPLOAD, DEFAULT_LABELS + + +def test_library__when_upload__should_get_file_id(file_in_library: str): + assert file_in_library is not None + + +def test_library__when_list__should_get_file_id_in_list_of_files(file_in_library: str): + client = AI21Client() + + files = client.library.files.list() + assert files[0].file_id == file_in_library + assert files[0].name == Path(LIBRARY_FILE_TO_UPLOAD).name + + +def test_library__when_get__should_match_file_id(file_in_library: str): + client = AI21Client() + + file_response = client.library.files.get(file_in_library) + assert file_response.file_id == file_in_library + + +def test_library__when_update__should_update_labels_successfully(file_in_library: str): + client = AI21Client() + + file_response = client.library.files.get(file_in_library) + assert set(file_response.labels) == set(DEFAULT_LABELS) + sleep(2) + + new_labels = DEFAULT_LABELS + ["new_label"] + client.library.files.update(file_in_library, labels=new_labels) + file_response = client.library.files.get(file_in_library) + assert set(file_response.labels) == set(new_labels) diff --git a/tests/integration_tests/clients/studio/test_library_answer.py b/tests/integration_tests/clients/studio/test_library_answer.py new file mode 100644 index 00000000..b625fb0c --- /dev/null +++ b/tests/integration_tests/clients/studio/test_library_answer.py @@ -0,0 +1,16 @@ +from ai21 import AI21Client + + +def test_library_answer__when_answer_not_in_context__should_return_false(file_in_library: str): + client = AI21Client() + response = client.library.answer.create(question="Who is Tony Stark?") + assert response.answer is None + assert not response.answer_in_context + + +def test_library_answer__when_answer_in_context__should_return_true(file_in_library: str): + client = AI21Client() + response = client.library.answer.create(question="Who was Albert Einstein?") + assert response.answer is not None + assert response.answer_in_context + assert response.sources[0].file_id == file_in_library diff --git a/tests/integration_tests/clients/studio/test_library_search.py b/tests/integration_tests/clients/studio/test_library_search.py new file mode 100644 index 00000000..ac9f5f10 --- /dev/null +++ b/tests/integration_tests/clients/studio/test_library_search.py @@ -0,0 +1,11 @@ +from ai21 import AI21Client + + +def test_library_search__when_search__should_return_relevant_results(file_in_library: str): + client = AI21Client() + response = client.library.search.create( + query="What did Albert Einstein get a Nobel Prize for?", labels=["einstein"] + ) + assert len(response.results) > 0 + for result in response.results: + assert result.file_id == file_in_library From 98c14c942f55affb7cd951ff35094c80eae4621e Mon Sep 17 00:00:00 2001 From: github-actions Date: Tue, 30 Jan 2024 16:04:53 +0000 Subject: [PATCH 38/45] chore(release): v2.0.0-rc.12 [skip ci] --- CHANGELOG.md | 214 ++++++++++++++++++++++++++++++++++++++++++++++++ ai21/version.py | 2 +- pyproject.toml | 2 +- 3 files changed, 216 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f4a0aea..9335c68b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,222 @@ +## v2.0.0-rc.12 (2024-01-30) + +### Chore + +* chore(deps): bump amannn/action-semantic-pull-request (#2) + +Bumps [amannn/action-semantic-pull-request](https://github.com/amannn/action-semantic-pull-request) from 5.0.2 to 5.4.0. +- [Release notes](https://github.com/amannn/action-semantic-pull-request/releases) +- [Changelog](https://github.com/amannn/action-semantic-pull-request/blob/main/CHANGELOG.md) +- [Commits](https://github.com/amannn/action-semantic-pull-request/compare/v5.0.2...v5.4.0) + +--- +updated-dependencies: +- dependency-name: amannn/action-semantic-pull-request + dependency-type: direct:production + update-type: version-update:semver-minor +... + +Signed-off-by: dependabot[bot] <support@github.com> +Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> +Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> ([`03a36a9`](https://github.com/AI21Labs/ai21-python/commit/03a36a9c9324a157b5aff2ecb11a1d6599ef3dee)) + +* chore(deps): bump actions/setup-python from 4 to 5 (#19) + +Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5. +- [Release notes](https://github.com/actions/setup-python/releases) +- [Commits](https://github.com/actions/setup-python/compare/v4...v5) + +--- +updated-dependencies: +- dependency-name: actions/setup-python + dependency-type: direct:production + update-type: version-update:semver-major +... + +Signed-off-by: dependabot[bot] <support@github.com> +Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> +Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> ([`40d6df8`](https://github.com/AI21Labs/ai21-python/commit/40d6df85cac9832753a1a5f5c218cc0ad0a6d4f6)) + +* chore(deps): bump pypa/gh-action-pypi-publish from 1.4.2 to 1.8.11 (#20) + +Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.4.2 to 1.8.11. +- [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) +- [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/27b31702a0e7fc50959f5ad993c78deac1bdfc29...2f6f737ca5f74c637829c0f5c3acd0e29ea5e8bf) + +--- +updated-dependencies: +- dependency-name: pypa/gh-action-pypi-publish + dependency-type: direct:production + update-type: version-update:semver-minor +... + +Signed-off-by: dependabot[bot] <support@github.com> +Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> +Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> ([`665c953`](https://github.com/AI21Labs/ai21-python/commit/665c9533711670c358336c6674391f8afcbe8bbb)) + +* chore(deps): bump actions/upload-artifact from 3 to 4 (#21) + +Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4. +- [Release notes](https://github.com/actions/upload-artifact/releases) +- [Commits](https://github.com/actions/upload-artifact/compare/v3...v4) + +--- +updated-dependencies: +- dependency-name: actions/upload-artifact + dependency-type: direct:production + update-type: version-update:semver-major +... + +Signed-off-by: dependabot[bot] <support@github.com> +Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> +Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> ([`ee75fb5`](https://github.com/AI21Labs/ai21-python/commit/ee75fb5ce411c9a14879adefe9800dbf89486f26)) + +* chore(deps-dev): bump gitpython from 3.1.40 to 3.1.41 (#37) + +Bumps [gitpython](https://github.com/gitpython-developers/GitPython) from 3.1.40 to 3.1.41. +- [Release notes](https://github.com/gitpython-developers/GitPython/releases) +- [Changelog](https://github.com/gitpython-developers/GitPython/blob/main/CHANGES) +- [Commits](https://github.com/gitpython-developers/GitPython/compare/3.1.40...3.1.41) + +--- +updated-dependencies: +- dependency-name: gitpython + dependency-type: indirect +... + +Signed-off-by: dependabot[bot] <support@github.com> +Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> +Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> ([`bf3e740`](https://github.com/AI21Labs/ai21-python/commit/bf3e7409386057e0b6f720594539e0ef5c2823ae)) + +* chore(deps-dev): bump jinja2 from 3.1.2 to 3.1.3 (#38) + +Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.2 to 3.1.3. +- [Release notes](https://github.com/pallets/jinja/releases) +- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) +- [Commits](https://github.com/pallets/jinja/compare/3.1.2...3.1.3) + +--- +updated-dependencies: +- dependency-name: jinja2 + dependency-type: indirect +... + +Signed-off-by: dependabot[bot] <support@github.com> +Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> +Co-authored-by: asafgardin <147075902+asafgardin@users.noreply.github.com> ([`347e7f9`](https://github.com/AI21Labs/ai21-python/commit/347e7f98d905027c0b429a5a93ad1fd94cc3b36d)) + +### Ci + +* ci: Add rc branch prefix trigger for integration tests (#43) + +* ci: rc branch trigger for integration test + +* fix: wrapped in quotes ([`0eacdbb`](https://github.com/AI21Labs/ai21-python/commit/0eacdbb07e62ed415e6c8cfe0cd02b2296c63eb4)) + +### Fix + +* fix: Integration tests (#41) + +* fix: types + +* test: Added some integration tests + +* test: improvements + +* test: test_paraphrase.py + +* fix: doc + +* fix: removed unused comment + +* test: test_summarize.py + +* test: Added tests for test_summarize_by_segment.py + +* test: test_segmentation.py + +* fix: file id in library response + +* fix: example for library + +* ci: Add rc branch prefix trigger for integration tests (#43) + +* ci: rc branch trigger for integration test + +* fix: wrapped in quotes + +* fix: types + +* test: Added some integration tests + +* test: improvements + +* test: test_paraphrase.py + +* fix: doc + +* fix: removed unused comment + +* test: test_summarize.py + +* test: Added tests for test_summarize_by_segment.py + +* test: test_segmentation.py + +* fix: file id in library response + +* fix: example for library + +* docs: docstrings + +* fix: question + +* fix: CR + +* test: Added tests to segment type in embed ([`78709a7`](https://github.com/AI21Labs/ai21-python/commit/78709a7ab195fe5cf21b03d646e6e7afb6775289)) + +* fix: aws tests (#44) + +* ci: rc branch trigger for integration test + +* fix: wrapped in quotes + +* fix: AWS tests + +* test: ci + +* fix: AWS tests + +* test: ci + +* fix: Removed testing pattern for tests ([`127cef4`](https://github.com/AI21Labs/ai21-python/commit/127cef460d295b7d7a3af83848f30403cf04ba4b)) + +* fix: added user agent with more details (#42) + +* test: added user agent with more details + +* test: changed to capital + +* test: Removed class into single tests ([`d7208ab`](https://github.com/AI21Labs/ai21-python/commit/d7208ab1b3750f87fcd435fd5018c4866e51f748)) + +### Test + +* test: Added tests for library (#45) + +* test: Added tests for library + +* fix: CR ([`e7fff80`](https://github.com/AI21Labs/ai21-python/commit/e7fff8081036c9c9316ca45389a8f80cc910314a)) + + ## v2.0.0-rc.11 (2024-01-22) +### Chore + +* chore(release): v2.0.0-rc.11 [skip ci] ([`ec1f977`](https://github.com/AI21Labs/ai21-python/commit/ec1f9774924f0f44d7449753d364661f5898ed5e)) + ### Fix * fix: top_k_returns to top_k_return (#40) ([`fe3765c`](https://github.com/AI21Labs/ai21-python/commit/fe3765c800f82b9bd350a851c875c30c45aa41a8)) diff --git a/ai21/version.py b/ai21/version.py index e5e933e1..48b86f9f 100644 --- a/ai21/version.py +++ b/ai21/version.py @@ -1 +1 @@ -VERSION = "2.0.0-rc.11" +VERSION = "2.0.0-rc.12" diff --git a/pyproject.toml b/pyproject.toml index 4fd87717..bbcbd1fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ exclude_lines = [ [tool.poetry] name = "ai21" -version = "2.0.0-rc.11" +version = "2.0.0-rc.12" description = "" authors = ["AI21 Labs"] readme = "README.md" From ebffd9551a598fc277495e4dfbe9b6373c4a1bfd Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 31 Jan 2024 09:53:58 +0200 Subject: [PATCH 39/45] fix: Removed unnecessary pre commit hook --- .pre-commit-config.yaml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28d2342b..9c210e34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -64,15 +64,6 @@ repos: args: - --schemafile - http://json.schemastore.org/prettierrc - - id: check-jsonschema - name: Validate ArgoWorkflow files - files: ^workflows/template/.* - types: - - yaml - args: - - --verbose - - --schemafile - - https://raw.githubusercontent.com/argoproj/argo-workflows/master/api/jsonschema/schema.json - repo: https://github.com/python-poetry/poetry rev: 1.5.0 hooks: From db17a962b88f1f5181ee81a81d0cbd7e20410918 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 31 Jan 2024 10:31:30 +0200 Subject: [PATCH 40/45] fix: Removed autouse --- tests/integration_tests/clients/studio/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/studio/conftest.py b/tests/integration_tests/clients/studio/conftest.py index 043f5b98..fcc3dfaf 100644 --- a/tests/integration_tests/clients/studio/conftest.py +++ b/tests/integration_tests/clients/studio/conftest.py @@ -30,7 +30,7 @@ def _delete_file(client: AI21Client, file_id: str): client.library.files.delete(file_id) -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(scope="module") def file_in_library(): """ Uploads a file to the library and deletes it after the test is done From d42127b9a45632051426ad387dc677d7675a873a Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 31 Jan 2024 10:59:27 +0200 Subject: [PATCH 41/45] fix: CR --- .../bedrock/resources/bedrock_completion.py | 3 +- ai21/clients/common/answer_base.py | 3 +- ai21/clients/common/chat_base.py | 4 +-- ai21/clients/common/completion_base.py | 3 +- ai21/clients/common/custom_model_base.py | 2 +- ai21/clients/common/dataset_base.py | 2 +- ai21/clients/common/embed_base.py | 3 +- ai21/clients/common/gec_base.py | 2 +- ai21/clients/common/improvements_base.py | 3 +- ai21/clients/common/paraphrase_base.py | 3 +- ai21/clients/common/segmentation_base.py | 3 +- ai21/clients/common/summarize_base.py | 3 +- .../common/summarize_by_segment_base.py | 5 +-- .../sagemaker/resources/sagemaker_answer.py | 4 +-- .../resources/sagemaker_completion.py | 3 +- .../sagemaker/resources/sagemaker_gec.py | 2 +- .../resources/sagemaker_paraphrase.py | 3 +- .../resources/sagemaker_summarize.py | 3 +- .../clients/studio/resources/studio_answer.py | 3 +- ai21/clients/studio/resources/studio_chat.py | 4 +-- .../studio/resources/studio_completion.py | 3 +- .../studio/resources/studio_custom_model.py | 2 +- .../studio/resources/studio_dataset.py | 2 +- ai21/clients/studio/resources/studio_embed.py | 3 +- ai21/clients/studio/resources/studio_gec.py | 2 +- .../studio/resources/studio_improvements.py | 3 +- .../studio/resources/studio_library.py | 5 +-- .../studio/resources/studio_paraphrase.py | 3 +- .../studio/resources/studio_segmentation.py | 3 +- .../studio/resources/studio_summarize.py | 3 +- .../resources/studio_summarize_by_segment.py | 5 +-- examples/studio/custom_model_completion.py | 32 +++++++++++++++++-- examples/studio/library.py | 3 +- examples/studio/library_answer.py | 1 - examples/studio/tokenization.py | 1 - .../clients/studio/conftest.py | 9 ++++-- .../clients/studio/resources/conftest.py | 10 ++++-- .../studio/resources/test_studio_resources.py | 2 +- tests/unittests/services/test_sagemaker.py | 2 +- 39 files changed, 79 insertions(+), 76 deletions(-) diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index 3f1418e3..e8617342 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -1,8 +1,7 @@ from typing import Optional, List from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource -from ai21.models import Penalty -from ai21.models.responses.completion_response import CompletionsResponse +from ai21.models import Penalty, CompletionsResponse class BedrockCompletion(BedrockResource): diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py index 1821db9a..43188d8d 100644 --- a/ai21/clients/common/answer_base.py +++ b/ai21/clients/common/answer_base.py @@ -1,8 +1,7 @@ from abc import ABC from typing import Optional, Any, Dict -from ai21.models import Mode, AnswerLength -from ai21.models.responses.answer_response import AnswerResponse +from ai21.models import Mode, AnswerLength, AnswerResponse class Answer(ABC): diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index 03037491..dee9fc4d 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -1,9 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Any, Dict, Optional -from ai21.models.chat_message import ChatMessage -from ai21.models.penalty import Penalty -from ai21.models.responses.chat_response import ChatResponse +from ai21.models import Penalty, ChatResponse, ChatMessage class Chat(ABC): diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index 06fef7d8..abc338f8 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, List, Dict, Any -from ai21.models import Penalty -from ai21.models.responses.completion_response import CompletionsResponse +from ai21.models import Penalty, CompletionsResponse class Completion(ABC): diff --git a/ai21/clients/common/custom_model_base.py b/ai21/clients/common/custom_model_base.py index 303b1adf..83de0449 100644 --- a/ai21/clients/common/custom_model_base.py +++ b/ai21/clients/common/custom_model_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, List, Any, Dict -from ai21.models.responses.custom_model_response import CustomBaseModelResponse +from ai21.models import CustomBaseModelResponse class CustomModel(ABC): diff --git a/ai21/clients/common/dataset_base.py b/ai21/clients/common/dataset_base.py index 732dee39..e1f3c16d 100644 --- a/ai21/clients/common/dataset_base.py +++ b/ai21/clients/common/dataset_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.models.responses.dataset_response import DatasetResponse +from ai21.models import DatasetResponse class Dataset(ABC): diff --git a/ai21/clients/common/embed_base.py b/ai21/clients/common/embed_base.py index aaf9363e..d2951a5e 100644 --- a/ai21/clients/common/embed_base.py +++ b/ai21/clients/common/embed_base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Any, Dict, Optional -from ai21.models.embed_type import EmbedType -from ai21.models.responses.embed_response import EmbedResponse +from ai21.models import EmbedType, EmbedResponse class Embed(ABC): diff --git a/ai21/clients/common/gec_base.py b/ai21/clients/common/gec_base.py index 091e6427..5b05ccde 100644 --- a/ai21/clients/common/gec_base.py +++ b/ai21/clients/common/gec_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, Any -from ai21.models.responses.gec_response import GECResponse +from ai21.models import GECResponse class GEC(ABC): diff --git a/ai21/clients/common/improvements_base.py b/ai21/clients/common/improvements_base.py index df13fe58..c3f1256a 100644 --- a/ai21/clients/common/improvements_base.py +++ b/ai21/clients/common/improvements_base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -from ai21.models import ImprovementType -from ai21.models.responses.improvement_response import ImprovementsResponse +from ai21.models import ImprovementType, ImprovementsResponse class Improvements(ABC): diff --git a/ai21/clients/common/paraphrase_base.py b/ai21/clients/common/paraphrase_base.py index 3c01bbb7..511feb84 100644 --- a/ai21/clients/common/paraphrase_base.py +++ b/ai21/clients/common/paraphrase_base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.models import ParaphraseStyleType -from ai21.models.responses.paraphrase_response import ParaphraseResponse +from ai21.models import ParaphraseStyleType, ParaphraseResponse class Paraphrase(ABC): diff --git a/ai21/clients/common/segmentation_base.py b/ai21/clients/common/segmentation_base.py index 97c74104..b4d2690e 100644 --- a/ai21/clients/common/segmentation_base.py +++ b/ai21/clients/common/segmentation_base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict -from ai21.models.document_type import DocumentType -from ai21.models.responses.segmentation_response import SegmentationResponse +from ai21.models import DocumentType, SegmentationResponse class Segmentation(ABC): diff --git a/ai21/clients/common/summarize_base.py b/ai21/clients/common/summarize_base.py index 2a70dfcd..ee68cdb8 100644 --- a/ai21/clients/common/summarize_base.py +++ b/ai21/clients/common/summarize_base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.models.responses.summarize_response import SummarizeResponse -from ai21.models.summary_method import SummaryMethod +from ai21.models import SummarizeResponse, SummaryMethod class Summarize(ABC): diff --git a/ai21/clients/common/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py index 236337de..6684dc2f 100644 --- a/ai21/clients/common/summarize_by_segment_base.py +++ b/ai21/clients/common/summarize_by_segment_base.py @@ -1,10 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Dict -from ai21.models.document_type import DocumentType -from ai21.models.responses.summarize_by_segment_response import ( - SummarizeBySegmentResponse, -) +from ai21.models import DocumentType, SummarizeBySegmentResponse class SummarizeBySegment(ABC): diff --git a/ai21/clients/sagemaker/resources/sagemaker_answer.py b/ai21/clients/sagemaker/resources/sagemaker_answer.py index f344d6b0..03760daf 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_answer.py +++ b/ai21/clients/sagemaker/resources/sagemaker_answer.py @@ -1,8 +1,8 @@ from typing import Optional -from ai21.clients.common.answer_base import Answer, AnswerLength, Mode -from ai21.models.responses.answer_response import AnswerResponse +from ai21.clients.common.answer_base import Answer from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.models import AnswerResponse, AnswerLength, Mode class SageMakerAnswer(SageMakerResource, Answer): diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 7aeddfaa..d850eca4 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -1,8 +1,7 @@ from typing import Optional, List from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource -from ai21.models import Penalty -from ai21.models.responses.completion_response import CompletionsResponse +from ai21.models import Penalty, CompletionsResponse class SageMakerCompletion(SageMakerResource): diff --git a/ai21/clients/sagemaker/resources/sagemaker_gec.py b/ai21/clients/sagemaker/resources/sagemaker_gec.py index 0750a7ea..138ac0bf 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_gec.py +++ b/ai21/clients/sagemaker/resources/sagemaker_gec.py @@ -1,6 +1,6 @@ from ai21.clients.common.gec_base import GEC -from ai21.models.responses.gec_response import GECResponse from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.models import GECResponse class SageMakerGEC(SageMakerResource, GEC): diff --git a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py index b2588019..40251d4c 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py +++ b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py @@ -1,9 +1,8 @@ from typing import Optional from ai21.clients.common.paraphrase_base import Paraphrase -from ai21.models.paraphrase_style_type import ParaphraseStyleType -from ai21.models.responses.paraphrase_response import ParaphraseResponse from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.models import ParaphraseStyleType, ParaphraseResponse class SageMakerParaphrase(SageMakerResource, Paraphrase): diff --git a/ai21/clients/sagemaker/resources/sagemaker_summarize.py b/ai21/clients/sagemaker/resources/sagemaker_summarize.py index 1d5a7bc2..b8f52e0b 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_summarize.py +++ b/ai21/clients/sagemaker/resources/sagemaker_summarize.py @@ -3,9 +3,8 @@ from typing import Optional from ai21.clients.common.summarize_base import Summarize -from ai21.models.summary_method import SummaryMethod -from ai21.models.responses import SummarizeResponse from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.models import SummarizeResponse, SummaryMethod class SageMakerSummarize(SageMakerResource, Summarize): diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index a8cfc13e..6fe86c5e 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -1,9 +1,8 @@ from typing import Optional from ai21.clients.common.answer_base import Answer -from ai21.models import AnswerLength, Mode -from ai21.models.responses.answer_response import AnswerResponse from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import AnswerLength, Mode, AnswerResponse class StudioAnswer(StudioResource, Answer): diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 76008ba9..f37d9d74 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -2,9 +2,7 @@ from ai21.clients.common.chat_base import Chat from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.chat_message import ChatMessage -from ai21.models.penalty import Penalty -from ai21.models.responses.chat_response import ChatResponse +from ai21.models import ChatMessage, Penalty, ChatResponse class StudioChat(StudioResource, Chat): diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 75364113..3b2cfc77 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -2,8 +2,7 @@ from ai21.clients.common.completion_base import Completion from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models import Penalty -from ai21.models.responses.completion_response import CompletionsResponse +from ai21.models import Penalty, CompletionsResponse class StudioCompletion(StudioResource, Completion): diff --git a/ai21/clients/studio/resources/studio_custom_model.py b/ai21/clients/studio/resources/studio_custom_model.py index 831b4575..e3410930 100644 --- a/ai21/clients/studio/resources/studio_custom_model.py +++ b/ai21/clients/studio/resources/studio_custom_model.py @@ -1,8 +1,8 @@ from typing import List, Optional from ai21.clients.common.custom_model_base import CustomModel -from ai21.models.responses.custom_model_response import CustomBaseModelResponse from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import CustomBaseModelResponse class StudioCustomModel(StudioResource, CustomModel): diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index ccfb4bac..1c77e642 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -1,8 +1,8 @@ from typing import Optional, List from ai21.clients.common.dataset_base import Dataset -from ai21.models.responses.dataset_response import DatasetResponse from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import DatasetResponse class StudioDataset(StudioResource, Dataset): diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index 7e7c8fad..1ef4e4e8 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -2,8 +2,7 @@ from ai21.clients.common.embed_base import Embed from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.embed_type import EmbedType -from ai21.models.responses.embed_response import EmbedResponse +from ai21.models import EmbedType, EmbedResponse class StudioEmbed(StudioResource, Embed): diff --git a/ai21/clients/studio/resources/studio_gec.py b/ai21/clients/studio/resources/studio_gec.py index 3e716e7a..d7c0db6e 100644 --- a/ai21/clients/studio/resources/studio_gec.py +++ b/ai21/clients/studio/resources/studio_gec.py @@ -1,6 +1,6 @@ from ai21.clients.common.gec_base import GEC -from ai21.models.responses.gec_response import GECResponse from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import GECResponse class StudioGEC(StudioResource, GEC): diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 513a413c..8118bb66 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -3,8 +3,7 @@ from ai21.clients.common.improvements_base import Improvements from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.errors import EmptyMandatoryListError -from ai21.models import ImprovementType -from ai21.models.responses.improvement_response import ImprovementsResponse +from ai21.models import ImprovementType, ImprovementsResponse class StudioImprovements(StudioResource, Improvements): diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index fa73b123..782fad51 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -1,11 +1,8 @@ from typing import Optional, List from ai21.ai21_http_client import AI21HTTPClient -from ai21.models import Mode, AnswerLength -from ai21.models.responses.file_response import FileResponse -from ai21.models.responses.library_answer_response import LibraryAnswerResponse -from ai21.models.responses.library_search_response import LibrarySearchResponse from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import Mode, AnswerLength, FileResponse, LibraryAnswerResponse, LibrarySearchResponse class StudioLibrary(StudioResource): diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py index 686841db..70348d78 100644 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ b/ai21/clients/studio/resources/studio_paraphrase.py @@ -2,8 +2,7 @@ from ai21.clients.common.paraphrase_base import Paraphrase from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models import ParaphraseStyleType -from ai21.models.responses.paraphrase_response import ParaphraseResponse +from ai21.models import ParaphraseStyleType, ParaphraseResponse class StudioParaphrase(StudioResource, Paraphrase): diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index dbda4225..0215e7df 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -1,7 +1,6 @@ from ai21.clients.common.segmentation_base import Segmentation from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.document_type import DocumentType -from ai21.models.responses.segmentation_response import SegmentationResponse +from ai21.models import DocumentType, SegmentationResponse class StudioSegmentation(StudioResource, Segmentation): diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index 4180ff52..5bab4acc 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -2,8 +2,7 @@ from ai21.clients.common.summarize_base import Summarize from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.responses.summarize_response import SummarizeResponse -from ai21.models.summary_method import SummaryMethod +from ai21.models import SummarizeResponse, SummaryMethod class StudioSummarize(StudioResource, Summarize): diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py index 8292a1f9..ba54b89b 100644 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ b/ai21/clients/studio/resources/studio_summarize_by_segment.py @@ -1,11 +1,8 @@ from typing import Optional from ai21.clients.common.summarize_by_segment_base import SummarizeBySegment -from ai21.models.document_type import DocumentType -from ai21.models.responses.summarize_by_segment_response import ( - SummarizeBySegmentResponse, -) from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import SummarizeBySegmentResponse, DocumentType class StudioSummarizeBySegment(StudioResource, SummarizeBySegment): diff --git a/examples/studio/custom_model_completion.py b/examples/studio/custom_model_completion.py index 9e9118b0..f06f21cf 100644 --- a/examples/studio/custom_model_completion.py +++ b/examples/studio/custom_model_completion.py @@ -1,7 +1,35 @@ from ai21 import AI21Client - -prompt = "The following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Male.\n\nConversation:\nUser: Hi, had a question\nMax: Hi there, happy to help!\nUser: Is there no way to return a product? I got your blue T-Shirt size small but it doesn't fit.\nMax: I'm sorry to hear that. Unfortunately we don't have a return policy. \nUser: That's a shame. \nMax: Is there anything else i can do for you?\n\n##\n\nThe following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you'll have the \"Blue & White\" t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\nMax: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\nMax: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\nMax: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\nMax: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, I have a question for you" +prompt = ( + "The following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- " + "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" + " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " + "Hi there, happy to help!\nUser: Is there no way to return a product?" + " I got your blue T-Shirt size small but it doesn't fit.\n" + "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" + "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" + "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help. The following are important points about" + " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" + 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' + "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" + " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" + "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" + " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" + " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" + " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" + "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" + "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" + "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" + " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" + " are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" + "User: Hi, I have a question for you" +) client = AI21Client() response = client.completion.create( diff --git a/examples/studio/library.py b/examples/studio/library.py index ca8e4840..0ec21c95 100644 --- a/examples/studio/library.py +++ b/examples/studio/library.py @@ -2,8 +2,7 @@ import uuid import file_utils -from ai21 import AI21Client -from ai21.errors import AI21APIError +from ai21 import AI21Client, AI21APIError # Use api_host for testing staging, default is production # os.environ["AI21_API_HOST"] = "https://api-stage.ai21.com" diff --git a/examples/studio/library_answer.py b/examples/studio/library_answer.py index 20d46402..1c2ae02f 100644 --- a/examples/studio/library_answer.py +++ b/examples/studio/library_answer.py @@ -1,6 +1,5 @@ from ai21 import AI21Client - client = AI21Client() response = client.library.answer.create(question="Can you tell me something about Holland?") print(response) diff --git a/examples/studio/tokenization.py b/examples/studio/tokenization.py index 21407185..2e9fb063 100644 --- a/examples/studio/tokenization.py +++ b/examples/studio/tokenization.py @@ -1,6 +1,5 @@ from ai21 import AI21Client - prompt = ( "The following is a conversation between a user of an eCommerce store and a user operation" " associate called Max. Max is very kind and keen to help." diff --git a/tests/integration_tests/clients/studio/conftest.py b/tests/integration_tests/clients/studio/conftest.py index fcc3dfaf..47cc7026 100644 --- a/tests/integration_tests/clients/studio/conftest.py +++ b/tests/integration_tests/clients/studio/conftest.py @@ -25,7 +25,7 @@ def _wait_for_file_to_process(client: AI21Client, file_id: str, timeout: float = raise TimeoutError(f"Timeout: {timeout} seconds passed. File processing not completed") -def _delete_file(client: AI21Client, file_id: str): +def _delete_uploaded_file(client: AI21Client, file_id: str): _wait_for_file_to_process(client, file_id) client.library.files.delete(file_id) @@ -39,7 +39,12 @@ def file_in_library(): """ client = AI21Client() + # Delete any file that might be in the library due to failed tests + files = client.library.files.list() + for file in files: + client.library.files.delete(file.file_id) + file_id = client.library.files.create(file_path=LIBRARY_FILE_TO_UPLOAD, labels=DEFAULT_LABELS) _wait_for_file_to_process(client, file_id) yield file_id - _delete_file(client, file_id=file_id) + _delete_uploaded_file(client, file_id=file_id) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 8d038dd5..0ff39d68 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -5,9 +5,13 @@ from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion -from ai21.models import AnswerResponse, ChatMessage, RoleType, ChatResponse -from ai21.models.responses.chat_response import ChatOutput, FinishReason -from ai21.models.responses.completion_response import ( +from ai21.models import ( + AnswerResponse, + ChatMessage, + RoleType, + ChatResponse, + ChatOutput, + FinishReason, Prompt, Completion, CompletionData, diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index cc534fe2..4ce6308e 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -5,7 +5,7 @@ from ai21.ai21_http_client import AI21HTTPClient from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.responses.answer_response import AnswerResponse +from ai21.models import AnswerResponse from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion _BASE_URL = "https://test.api.ai21.com/studio/v1" diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py index dd36e1c9..c6f9e165 100644 --- a/tests/unittests/services/test_sagemaker.py +++ b/tests/unittests/services/test_sagemaker.py @@ -1,6 +1,6 @@ import pytest -from ai21.errors import ModelPackageDoesntExistError +from ai21 import ModelPackageDoesntExistError from tests.unittests.services.sagemaker_stub import SageMakerStub _DUMMY_ARN = "some-model-package-id1" From 934e0e7c042d04bb4fff9b538e0df1a42422b91b Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 31 Jan 2024 16:09:18 +0200 Subject: [PATCH 42/45] docs: CONTRIBUTING.md --- CONTRIBUTING.md | 103 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 CONTRIBUTING.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..a71b134f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,103 @@ +# Contributing to AI21 Python SDK + +We welcome contributions to the AI21 Python SDK. Please read the following guidelines before submitting your pull request. + +### Examples of contributions include: + +- Bug fixes +- Documentation improvements +- Additional tests + +## Reporting issues + +Go to this repository's [issues page](https://github.com/AI21Labs/ai21-python/issues) and click on the "New Issue" button. +Please make sure to check if the issue has already been reported before creating a new one. + +Include the following information in your post: + +- Describe what you expected to happen. +- If possible, include a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) to help us + identify the issue. This also helps check that the issue is not with + your own code. +- Describe what actually happened. Include the full traceback if there + was an exception. +- List your Python version. If possible, check if this + issue is already fixed in the latest releases or the latest code in + the repository. + +## Submit a pull request + +Fork the AI21 Python SDK repository and clone it to your local machine. Create a new branch for your changes: + + git clone https://github.com:AI21Labs/USERNAME/ai21-python + cd ai21-python + git checkout -b my-fix-branch master + +### Installation + +#### MacOS + +We recommend running the provided `init.sh` script to install the required dependencies and set up the development environment. This script will install poetry if not already installed. To run the script, simply run: + + ./init.sh + +#### Windows/Linux + +We recommend using [poetry](https://python-poetry.org/) to install the required dependencies and set up the development environment. To install poetry, run: + + pip install poetry + +Then, to install the required dependencies, run: + + poetry install + +After that Install [pre-commit](https://pre-commit.com/#installation) and run: + + pre-commit install --install-hooks -t pre-commit -t commit-msg + +Installing the pre-commit hooks would take care of formatting and linting your code before committing. +Please make sure you have the pre-commit hooks installed before committing your code. + +**We recommend creating your own venv using pyenv or virtualenv when working on this repository, in order to eliminate unnecessary dependencies from external libraries** + +### Commits + +Each commit should be a single logical change and should be aligned with the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification. +Since we are using a pre-commit hook to enforce this, any other commit message format will be rejected. + +### Run CI tasks locally + +```bash +$ inv --list +Available tasks: + + clean clean (remove) packages + lint python lint + outdated outdated packages + test Run unit tests + update update packages + audit run safety checks on project dependencies + formatter auto formats the modified files +``` + +### Tests + +We use [pytest](https://docs.pytest.org/en/stable/) for testing. To run the tests, run: + + inv test + +If adding a new test, please make sure to add it to the `tests` directory and have the file location be under the same hierarchy as the file being tested. + +Make sure you use `pytest` for tests writing and not any other testing framework. + +### How to open a pull request? + +Push your branch to your forked repository and open a pull request against the `main` branch of the AI21 Python SDK repository. Please make sure to include a description of your changes in the pull request. + +The title of the pull request should follow the above-mentioned [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification. + +### Feedback + +If you have any questions or feedback, please feel free to reach out to us. + +We appreciate and encourage any contributions to the AI21 Python SDK. Please take the reviewer feedback positively and make the necessary changes to your pull request. From ca0703d02589e50b51702ef88e97d000bfa923a1 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 31 Jan 2024 16:10:25 +0200 Subject: [PATCH 43/45] docs: LICENSE --- LICENSE | 201 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. From 7cca7191299bb1a370c343513de489bdef52af62 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 31 Jan 2024 16:18:40 +0200 Subject: [PATCH 44/45] fix: removed license --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 8db436ba..2ba17979 100755 --- a/setup.py +++ b/setup.py @@ -14,7 +14,6 @@ setup( name="ai21", version=VERSION, - license="MIT", author="AI21 Labs", author_email="support@ai21.com", long_description_content_type="text/markdown", From 1782de8b31133db9cac1756da167ec819966f783 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Wed, 31 Jan 2024 17:17:12 +0200 Subject: [PATCH 45/45] test: Added some more unittests --- .../studio/resources/studio_segmentation.py | 2 +- .../models/responses/segmentation_response.py | 1 + .../summarize_by_segment_response.py | 1 + .../clients/studio/resources/conftest.py | 182 +++++++++++++++++- .../studio/resources/test_studio_resources.py | 27 ++- 5 files changed, 209 insertions(+), 4 deletions(-) diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index 0215e7df..11d92814 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -5,7 +5,7 @@ class StudioSegmentation(StudioResource, Segmentation): def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: - body = self._create_body(source=source, source_type=source_type) + body = self._create_body(source=source, source_type=source_type.value) url = f"{self._client.get_base_url()}/{self._module_name}" raw_response = self._post(url=url, body=body) diff --git a/ai21/models/responses/segmentation_response.py b/ai21/models/responses/segmentation_response.py index 1ceb1120..ed7021ac 100644 --- a/ai21/models/responses/segmentation_response.py +++ b/ai21/models/responses/segmentation_response.py @@ -12,4 +12,5 @@ class Segment(AI21BaseModelMixin): @dataclass class SegmentationResponse(AI21BaseModelMixin): + id: str segments: List[Segment] diff --git a/ai21/models/responses/summarize_by_segment_response.py b/ai21/models/responses/summarize_by_segment_response.py index 4f099a36..d780454f 100644 --- a/ai21/models/responses/summarize_by_segment_response.py +++ b/ai21/models/responses/summarize_by_segment_response.py @@ -23,4 +23,5 @@ class SegmentSummary(AI21BaseModelMixin): @dataclass class SummarizeBySegmentResponse(AI21BaseModelMixin): + id: str segments: List[SegmentSummary] diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 0ff39d68..1849cf42 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -5,6 +5,13 @@ from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion +from ai21.clients.studio.resources.studio_embed import StudioEmbed +from ai21.clients.studio.resources.studio_gec import StudioGEC +from ai21.clients.studio.resources.studio_improvements import StudioImprovements +from ai21.clients.studio.resources.studio_paraphrase import StudioParaphrase +from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation +from ai21.clients.studio.resources.studio_summarize import StudioSummarize +from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment from ai21.models import ( AnswerResponse, ChatMessage, @@ -17,7 +24,26 @@ CompletionData, CompletionFinishReason, CompletionsResponse, + EmbedType, + EmbedResponse, + EmbedResult, + GECResponse, + Correction, + CorrectionType, + ImprovementType, + ImprovementsResponse, + Improvement, + ParaphraseStyleType, + ParaphraseResponse, + Suggestion, + DocumentType, + SegmentationResponse, + SummaryMethod, + SummarizeResponse, + SummarizeBySegmentResponse, + SegmentSummary, ) +from ai21.models.responses.segmentation_response import Segment @pytest.fixture @@ -122,5 +148,157 @@ def get_studio_completion(): ) -def get_studio_custom_model(): - pass +def get_studio_embed(): + return ( + StudioEmbed, + {"texts": ["text0", "text1"], "type": EmbedType.QUERY}, + "embed", + { + "texts": ["text0", "text1"], + "type": EmbedType.QUERY.value, + }, + EmbedResponse( + id="some-id", + results=[ + EmbedResult([1.0, 2.0, 3.0]), + EmbedResult([4.0, 5.0, 6.0]), + ], + ), + ) + + +def get_studio_gec(): + text = "text to fi" + return ( + StudioGEC, + {"text": text}, + "gec", + { + "text": text, + }, + GECResponse( + id="some-id", + corrections=[ + Correction( + suggestion="text to fix", + start_index=9, + end_index=10, + original_text=text, + correction_type=CorrectionType.SPELLING, + ) + ], + ), + ) + + +def get_studio_improvements(): + text = "text to improve" + types = [ImprovementType.FLUENCY] + return ( + StudioImprovements, + {"text": text, "types": types}, + "improvements", + { + "text": text, + "types": types, + }, + ImprovementsResponse( + id="some-id", + improvements=[ + Improvement( + suggestions=["This text is improved"], + start_index=0, + end_index=15, + original_text=text, + improvement_type=ImprovementType.FLUENCY, + ) + ], + ), + ) + + +def get_studio_paraphrase(): + text = "text to paraphrase" + style = ParaphraseStyleType.CASUAL + start_index = 0 + end_index = 10 + return ( + StudioParaphrase, + {"text": text, "style": style, "start_index": start_index, "end_index": end_index}, + "paraphrase", + { + "text": text, + "style": style, + "startIndex": start_index, + "endIndex": end_index, + }, + ParaphraseResponse(id="some-id", suggestions=[Suggestion(text="This text is paraphrased")]), + ) + + +def get_studio_segmentation(): + source = "segmentation text" + source_type = DocumentType.TEXT + return ( + StudioSegmentation, + {"source": source, "source_type": source_type}, + "segmentation", + { + "source": source, + "sourceType": source_type, + }, + SegmentationResponse( + id="some-id", segments=[Segment(segment_text="This text is segmented", segment_type="segment_type")] + ), + ) + + +def get_studio_summarization(): + source = "text to summarize" + source_type = "TEXT" + focus = "text" + summary_method = SummaryMethod.FULL_DOCUMENT + return ( + StudioSummarize, + {"source": source, "source_type": source_type, "focus": focus, "summary_method": summary_method}, + "summarize", + { + "source": source, + "sourceType": source_type, + "focus": focus, + "summaryMethod": summary_method, + }, + SummarizeResponse( + id="some-id", + summary="This text is summarized", + ), + ) + + +def get_studio_summarize_by_segment(): + source = "text to summarize" + source_type = "TEXT" + focus = "text" + return ( + StudioSummarizeBySegment, + {"source": source, "source_type": source_type, "focus": focus}, + "summarize-by-segment", + { + "source": source, + "sourceType": source_type, + "focus": focus, + }, + SummarizeBySegmentResponse( + id="some-id", + segments=[ + SegmentSummary( + summary="This text is summarized", + segment_text="This text is segmented", + segment_type="segment_type", + segment_html=None, + has_summary=True, + highlights=[], + ) + ], + ), + ) diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index 4ce6308e..cea0df9d 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -6,7 +6,18 @@ from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import AnswerResponse -from tests.unittests.clients.studio.resources.conftest import get_studio_answer, get_studio_chat, get_studio_completion +from tests.unittests.clients.studio.resources.conftest import ( + get_studio_answer, + get_studio_chat, + get_studio_completion, + get_studio_embed, + get_studio_gec, + get_studio_improvements, + get_studio_paraphrase, + get_studio_segmentation, + get_studio_summarization, + get_studio_summarize_by_segment, +) _BASE_URL = "https://test.api.ai21.com/studio/v1" _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" @@ -21,12 +32,26 @@ class TestStudioResources: "studio_answer", "studio_chat", "studio_completion", + "studio_embed", + "studio_gec", + "studio_improvements", + "studio_paraphrase", + "studio_segmentation", + "studio_summarization", + "studio_summarize_by_segment", ], argnames=["studio_resource", "function_body", "url_suffix", "expected_body", "expected_response"], argvalues=[ (get_studio_answer()), (get_studio_chat()), (get_studio_completion()), + (get_studio_embed()), + (get_studio_gec()), + (get_studio_improvements()), + (get_studio_paraphrase()), + (get_studio_segmentation()), + (get_studio_summarization()), + (get_studio_summarize_by_segment()), ], ) def test__create__should_return_answer_response(