diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index d66facce..947d59a9 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -13,4 +13,4 @@ from .supplier import Supplier from .sort_by import SortBy from .sort_order import SortOrder -from .model_status import ModelStatus +from .response_status import ResponseStatus \ No newline at end of file diff --git a/aixplain/enums/model_status.py b/aixplain/enums/model_status.py deleted file mode 100644 index af4ae0a9..00000000 --- a/aixplain/enums/model_status.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import Enum -from typing import Text - - -class ModelStatus(Text, Enum): - FAILED = "FAILED" - IN_PROGRESS = "IN_PROGRESS" - SUCCESS = "SUCCESS" - - def __str__(self): - return self._value_ diff --git a/aixplain/enums/response_status.py b/aixplain/enums/response_status.py new file mode 100644 index 00000000..d2810753 --- /dev/null +++ b/aixplain/enums/response_status.py @@ -0,0 +1,31 @@ +__author__ = "thiagocastroferreira" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: February 21st 2024 +Description: + Asset Enum +""" + +from enum import Enum +from typing import Text + + +class ResponseStatus(Text, Enum): + IN_PROGRESS = "IN_PROGRESS" + SUCCESS = "SUCCESS" + FAILED = "FAILED" diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 41abf865..2ac8b37f 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -32,7 +32,7 @@ from typing import Union, Optional, Text, Dict from datetime import datetime from aixplain.modules.model.response import ModelResponse -from aixplain.enums import ModelStatus +from aixplain.enums.response_status import ResponseStatus class Model(Asset): @@ -118,7 +118,9 @@ def __repr__(self): except Exception: return f"" - def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300) -> Dict: + def sync_poll( + self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300 + ) -> ModelResponse: """Keeps polling the platform to check whether an asynchronous call is done. Args: @@ -135,7 +137,7 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo # keep wait time as 0.2 seconds the minimum wait_time = max(wait_time, 0.2) completed = False - response_body = {"status": "FAILED", "completed": False} + response_body = ModelResponse(status=ResponseStatus.FAILED, completed=False) while not completed and (end - start) < timeout: try: response_body = self.poll(poll_url, name=name) @@ -147,13 +149,17 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo if wait_time < 60: wait_time *= 1.1 except Exception as e: - response_body = {"status": "FAILED", "completed": False, "error_message": "No response from the service."} + response_body = ModelResponse( + status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." + ) logging.error(f"Polling for Model: polling for {name}: {e}") break if response_body["completed"] is True: logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}") else: - response_body["status"] = "FAILED" + response_body = ModelResponse( + status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." + ) logging.error( f"Polling for Model: Final status of polling for {name}: No response in {timeout} seconds - {response_body}" ) @@ -174,11 +180,11 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: try: resp = r.json() if resp["completed"] is True: - status = ModelStatus.SUCCESS + status = ResponseStatus.SUCCESS if "error_message" in resp or "supplierError" in resp: - status = ModelStatus.FAILED + status = ResponseStatus.FAILED else: - status = ModelStatus.IN_PROGRESS + status = ResponseStatus.IN_PROGRESS logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}") return ModelResponse( status=resp.pop("status", status), @@ -195,7 +201,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: resp = {"status": "FAILED"} logging.error(f"Single Poll for Model: Error of polling for {name}: {e}") return ModelResponse( - status=ModelStatus.FAILED, + status=ResponseStatus.FAILED, error_message=str(e), completed=False, ) @@ -234,9 +240,9 @@ def run( msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Model Run: Error in running for {name}: {e}") end = time.time() - response = {"status": "FAILED", "error": msg, "runTime": end - start} + response = {"status": "FAILED", "error_message": msg, "runTime": end - start} return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), @@ -247,7 +253,9 @@ def run( **response, ) - def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> ModelResponse: + def run_async( + self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {} + ) -> ModelResponse: """Runs asynchronously a model call. Args: @@ -263,7 +271,7 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param payload = build_payload(data=data, parameters=parameters) response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index 941c4a6f..48bfcc11 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -29,7 +29,7 @@ from aixplain.utils import config from typing import Union, Optional, List, Text, Dict from aixplain.modules.model.response import ModelResponse -from aixplain.enums import ModelStatus +from aixplain.enums.response_status import ResponseStatus class LLM(Model): @@ -152,7 +152,7 @@ def run( end = time.time() response = {"status": "FAILED", "error": msg, "elapsed_time": end - start} return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), @@ -206,7 +206,7 @@ def run_async( payload = build_payload(data=data, parameters=parameters) response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index 42ed09a4..94ddcb9d 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -1,15 +1,13 @@ -from dataclasses import dataclass from typing import Text, Any, Optional, Dict, List, Union -from aixplain.enums import ModelStatus +from aixplain.enums import ResponseStatus -@dataclass class ModelResponse: """ModelResponse class to store the response of the model run.""" def __init__( self, - status: ModelStatus, + status: ResponseStatus, data: Text = "", details: Optional[Union[Dict, List]] = {}, completed: bool = False, diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index 1329e136..073ed3ac 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -4,7 +4,7 @@ load_dotenv() from aixplain.utils import config -from aixplain.enums import ModelStatus +from aixplain.enums import ResponseStatus from aixplain.modules.model.response import ModelResponse from aixplain.modules import LLM @@ -85,7 +85,7 @@ def test_run_sync(): response = test_model.run(data=input_data, temperature=0.001, max_tokens=128, top_p=1.0) assert isinstance(response, ModelResponse) - assert response.status == ModelStatus.SUCCESS + assert response.status == ResponseStatus.SUCCESS assert response.data == "Test Model Result" assert response.completed is True assert response.used_credits == 0 diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 94e2f6c2..9ddb6bc0 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -27,7 +27,7 @@ from aixplain.factories import ModelFactory from aixplain.enums import Function from urllib.parse import urljoin -from aixplain.enums import ModelStatus +from aixplain.enums import ResponseStatus from aixplain.modules.model.response import ModelResponse import pytest from unittest.mock import patch @@ -67,7 +67,7 @@ def test_call_run_endpoint_sync(): model_id = "model-id" execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute") payload = {"data": "input_data"} - ref_response = {"completed": True, "status": ModelStatus.SUCCESS, "data": "Hello"} + ref_response = {"completed": True, "status": ResponseStatus.SUCCESS, "data": "Hello"} with requests_mock.Mocker() as mock: mock.post(execute_url, json=ref_response) @@ -88,7 +88,7 @@ def test_success_poll(): hyp_response = test_model.poll(poll_url=poll_url) assert isinstance(hyp_response, ModelResponse) assert hyp_response["completed"] == ref_response["completed"] - assert hyp_response["status"] == ModelStatus.SUCCESS + assert hyp_response["status"] == ResponseStatus.SUCCESS def test_failed_poll(): @@ -103,7 +103,7 @@ def test_failed_poll(): response = model.poll(poll_url=poll_url) assert isinstance(response, ModelResponse) - assert response.status == ModelStatus.FAILED + assert response.status == ResponseStatus.FAILED assert response.error_message == "Some error occurred" assert response.completed is True @@ -145,7 +145,7 @@ def test_run_async_errors(status_code, error_message): test_model = Model(id=model_id, name="Test Model", url=base_url) response = test_model.run_async(data="input_data") assert isinstance(response, ModelResponse) - assert response["status"] == ModelStatus.FAILED + assert response["status"] == ResponseStatus.FAILED assert response["error_message"] == error_message @@ -219,7 +219,7 @@ def test_run_sync(): response = test_model.run(data=input_data, name="test_run") assert isinstance(response, ModelResponse) - assert response.status == ModelStatus.SUCCESS + assert response.status == ResponseStatus.SUCCESS assert response.data == "Test Model Result" assert response.completed is True assert response.used_credits == 0