From 2afbaf3d697d1ef68f68049ad1b9b14ba528fedc Mon Sep 17 00:00:00 2001 From: Jakob Herpel Date: Mon, 15 Jan 2024 10:28:17 +0100 Subject: [PATCH] Return num_tokens_prompt_total for evaluation --- Changelog.md | 5 +++++ aleph_alpha_client/evaluation.py | 2 ++ aleph_alpha_client/version.py | 4 ++-- tests/test_evaluate.py | 8 +++++--- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/Changelog.md b/Changelog.md index 7c67a83..2d33c19 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,10 @@ # Changelog +## 7.0.0 + +- Added `num_tokens_prompt_total` to `EvaluationResponse` +- HTTP API version 1.16.0 or higher is required. + ## 6.0.0 - Added `num_tokens_prompt_total` to the types below. diff --git a/aleph_alpha_client/evaluation.py b/aleph_alpha_client/evaluation.py index 3f8ebf1..01486de 100644 --- a/aleph_alpha_client/evaluation.py +++ b/aleph_alpha_client/evaluation.py @@ -53,11 +53,13 @@ class EvaluationResponse: model_version: str message: Optional[str] result: Dict[str, Any] + num_tokens_prompt_total: int @staticmethod def from_json(json: Dict[str, Any]) -> "EvaluationResponse": return EvaluationResponse( model_version=json["model_version"], result=json["result"], + num_tokens_prompt_total=json["num_tokens_prompt_total"], message=json.get("message"), ) diff --git a/aleph_alpha_client/version.py b/aleph_alpha_client/version.py index 4bd1e5e..85eb590 100644 --- a/aleph_alpha_client/version.py +++ b/aleph_alpha_client/version.py @@ -1,2 +1,2 @@ -__version__ = "6.0.0" -MIN_API_VERSION = "1.15.0" +__version__ = "7.0.0" +MIN_API_VERSION = "1.16.0" diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index c6f6f6b..465f18f 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -23,6 +23,7 @@ async def test_can_evaluate_with_async_client( response = await async_client.evaluate(request, model=model_name) assert response.model_version is not None assert response.result is not None + assert response.num_tokens_prompt_total >= 1 # Client @@ -34,7 +35,8 @@ def test_evaluate(sync_client: Client, model_name: str): prompt=Prompt.from_text("hello"), completion_expected="world" ) - result = sync_client.evaluate(request, model=model_name) + response = sync_client.evaluate(request, model=model_name) - assert result.model_version is not None - assert result.result is not None + assert response.model_version is not None + assert response.result is not None + assert response.num_tokens_prompt_total >= 1