Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aixplain/enums/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
from .supplier import Supplier
from .sort_by import SortBy
from .sort_order import SortOrder
from .model_status import ModelStatus
from .response_status import ResponseStatus
11 changes: 0 additions & 11 deletions aixplain/enums/model_status.py

This file was deleted.

31 changes: 31 additions & 0 deletions aixplain/enums/response_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
__author__ = "thiagocastroferreira"

"""
Copyright 2024 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli
Date: February 21st 2024
Description:
Asset Enum
"""

from enum import Enum
from typing import Text


class ResponseStatus(Text, Enum):
IN_PROGRESS = "IN_PROGRESS"
SUCCESS = "SUCCESS"
FAILED = "FAILED"
34 changes: 21 additions & 13 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from typing import Union, Optional, Text, Dict
from datetime import datetime
from aixplain.modules.model.response import ModelResponse
from aixplain.enums import ModelStatus
from aixplain.enums.response_status import ResponseStatus


class Model(Asset):
Expand Down Expand Up @@ -118,7 +118,9 @@ def __repr__(self):
except Exception:
return f"<Model: {self.name} by {self.supplier}>"

def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300) -> Dict:
def sync_poll(
self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300
) -> ModelResponse:
"""Keeps polling the platform to check whether an asynchronous call is done.

Args:
Expand All @@ -135,7 +137,7 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo
# keep wait time as 0.2 seconds the minimum
wait_time = max(wait_time, 0.2)
completed = False
response_body = {"status": "FAILED", "completed": False}
response_body = ModelResponse(status=ResponseStatus.FAILED, completed=False)
while not completed and (end - start) < timeout:
try:
response_body = self.poll(poll_url, name=name)
Expand All @@ -147,13 +149,17 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo
if wait_time < 60:
wait_time *= 1.1
except Exception as e:
response_body = {"status": "FAILED", "completed": False, "error_message": "No response from the service."}
response_body = ModelResponse(
status=ResponseStatus.FAILED, completed=False, error_message="No response from the service."
)
logging.error(f"Polling for Model: polling for {name}: {e}")
break
if response_body["completed"] is True:
logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}")
else:
response_body["status"] = "FAILED"
response_body = ModelResponse(
status=ResponseStatus.FAILED, completed=False, error_message="No response from the service."
)
logging.error(
f"Polling for Model: Final status of polling for {name}: No response in {timeout} seconds - {response_body}"
)
Expand All @@ -174,11 +180,11 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse:
try:
resp = r.json()
if resp["completed"] is True:
status = ModelStatus.SUCCESS
status = ResponseStatus.SUCCESS
if "error_message" in resp or "supplierError" in resp:
status = ModelStatus.FAILED
status = ResponseStatus.FAILED
else:
status = ModelStatus.IN_PROGRESS
status = ResponseStatus.IN_PROGRESS
logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}")
return ModelResponse(
status=resp.pop("status", status),
Expand All @@ -195,7 +201,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse:
resp = {"status": "FAILED"}
logging.error(f"Single Poll for Model: Error of polling for {name}: {e}")
return ModelResponse(
status=ModelStatus.FAILED,
status=ResponseStatus.FAILED,
error_message=str(e),
completed=False,
)
Expand Down Expand Up @@ -234,9 +240,9 @@ def run(
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Model Run: Error in running for {name}: {e}")
end = time.time()
response = {"status": "FAILED", "error": msg, "runTime": end - start}
response = {"status": "FAILED", "error_message": msg, "runTime": end - start}
return ModelResponse(
status=response.pop("status", ModelStatus.FAILED),
status=response.pop("status", ResponseStatus.FAILED),
data=response.pop("data", ""),
details=response.pop("details", {}),
completed=response.pop("completed", False),
Expand All @@ -247,7 +253,9 @@ def run(
**response,
)

def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> ModelResponse:
def run_async(
self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}
) -> ModelResponse:
"""Runs asynchronously a model call.

Args:
Expand All @@ -263,7 +271,7 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
payload = build_payload(data=data, parameters=parameters)
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
return ModelResponse(
status=response.pop("status", ModelStatus.FAILED),
status=response.pop("status", ResponseStatus.FAILED),
data=response.pop("data", ""),
details=response.pop("details", {}),
completed=response.pop("completed", False),
Expand Down
6 changes: 3 additions & 3 deletions aixplain/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from aixplain.utils import config
from typing import Union, Optional, List, Text, Dict
from aixplain.modules.model.response import ModelResponse
from aixplain.enums import ModelStatus
from aixplain.enums.response_status import ResponseStatus


class LLM(Model):
Expand Down Expand Up @@ -152,7 +152,7 @@ def run(
end = time.time()
response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
return ModelResponse(
status=response.pop("status", ModelStatus.FAILED),
status=response.pop("status", ResponseStatus.FAILED),
data=response.pop("data", ""),
details=response.pop("details", {}),
completed=response.pop("completed", False),
Expand Down Expand Up @@ -206,7 +206,7 @@ def run_async(
payload = build_payload(data=data, parameters=parameters)
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
return ModelResponse(
status=response.pop("status", ModelStatus.FAILED),
status=response.pop("status", ResponseStatus.FAILED),
data=response.pop("data", ""),
details=response.pop("details", {}),
completed=response.pop("completed", False),
Expand Down
6 changes: 2 additions & 4 deletions aixplain/modules/model/response.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from dataclasses import dataclass
from typing import Text, Any, Optional, Dict, List, Union
from aixplain.enums import ModelStatus
from aixplain.enums import ResponseStatus


@dataclass
class ModelResponse:
"""ModelResponse class to store the response of the model run."""

def __init__(
self,
status: ModelStatus,
status: ResponseStatus,
data: Text = "",
details: Optional[Union[Dict, List]] = {},
completed: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/llm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

load_dotenv()
from aixplain.utils import config
from aixplain.enums import ModelStatus
from aixplain.enums import ResponseStatus
from aixplain.modules.model.response import ModelResponse
from aixplain.modules import LLM

Expand Down Expand Up @@ -85,7 +85,7 @@ def test_run_sync():
response = test_model.run(data=input_data, temperature=0.001, max_tokens=128, top_p=1.0)

assert isinstance(response, ModelResponse)
assert response.status == ModelStatus.SUCCESS
assert response.status == ResponseStatus.SUCCESS
assert response.data == "Test Model Result"
assert response.completed is True
assert response.used_credits == 0
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from aixplain.factories import ModelFactory
from aixplain.enums import Function
from urllib.parse import urljoin
from aixplain.enums import ModelStatus
from aixplain.enums import ResponseStatus
from aixplain.modules.model.response import ModelResponse
import pytest
from unittest.mock import patch
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_call_run_endpoint_sync():
model_id = "model-id"
execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute")
payload = {"data": "input_data"}
ref_response = {"completed": True, "status": ModelStatus.SUCCESS, "data": "Hello"}
ref_response = {"completed": True, "status": ResponseStatus.SUCCESS, "data": "Hello"}

with requests_mock.Mocker() as mock:
mock.post(execute_url, json=ref_response)
Expand All @@ -88,7 +88,7 @@ def test_success_poll():
hyp_response = test_model.poll(poll_url=poll_url)
assert isinstance(hyp_response, ModelResponse)
assert hyp_response["completed"] == ref_response["completed"]
assert hyp_response["status"] == ModelStatus.SUCCESS
assert hyp_response["status"] == ResponseStatus.SUCCESS


def test_failed_poll():
Expand All @@ -103,7 +103,7 @@ def test_failed_poll():
response = model.poll(poll_url=poll_url)

assert isinstance(response, ModelResponse)
assert response.status == ModelStatus.FAILED
assert response.status == ResponseStatus.FAILED
assert response.error_message == "Some error occurred"
assert response.completed is True

Expand Down Expand Up @@ -145,7 +145,7 @@ def test_run_async_errors(status_code, error_message):
test_model = Model(id=model_id, name="Test Model", url=base_url)
response = test_model.run_async(data="input_data")
assert isinstance(response, ModelResponse)
assert response["status"] == ModelStatus.FAILED
assert response["status"] == ResponseStatus.FAILED
assert response["error_message"] == error_message


Expand Down Expand Up @@ -219,7 +219,7 @@ def test_run_sync():
response = test_model.run(data=input_data, name="test_run")

assert isinstance(response, ModelResponse)
assert response.status == ModelStatus.SUCCESS
assert response.status == ResponseStatus.SUCCESS
assert response.data == "Test Model Result"
assert response.completed is True
assert response.used_credits == 0
Expand Down