diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 2e9445b5..621eb522 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -21,11 +21,11 @@ Model Class """ import time -import json import logging import traceback from aixplain.enums import Supplier, Function from aixplain.modules.asset import Asset +from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from urllib.parse import urljoin from aixplain.utils.file_utils import _request_with_retry @@ -149,7 +149,7 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo logging.error(f"Polling for Model: polling for {name}: {e}") break if response_body["completed"] is True: - logging.info(f"Polling for Model: Final status of polling for {name}: {response_body}") + logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}") else: response_body["status"] = "FAILED" logging.error( @@ -204,21 +204,21 @@ def run( Dict: parsed output from model """ start = time.time() - try: - response = self.run_async(data, name=name, parameters=parameters) - if response["status"] == "FAILED": + payload = build_payload(data=data, parameters=parameters) + url = f"{self.url}/api/v2/execute/{self.id}" + logging.debug(f"Model Run Sync: Start service for {name} - {url}") + response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) + if response["status"] == "IN_PROGRESS": + try: + poll_url = response["url"] end = time.time() - response["elapsed_time"] = end - start - return response - poll_url = response["url"] - end = time.time() - response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) - return response - except Exception as e: - msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"Model Run: Error in running for {name}: {e}") - end = time.time() - return {"status": "FAILED", "error": 
msg, "elapsed_time": end - start} + response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + except Exception as e: + msg = f"Error in request for {name} - {traceback.format_exc()}" + logging.error(f"Model Run: Error in running for {name}: {e}") + end = time.time() + response = {"status": "FAILED", "error": msg, "elapsed_time": end - start} + return response def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Dict = {}) -> Dict: """Runs asynchronously a model call. @@ -231,59 +231,10 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param Returns: dict: polling URL in response """ - headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} - from aixplain.factories.file_factory import FileFactory - - data = FileFactory.to_link(data) - if isinstance(data, dict): - payload = data - else: - try: - payload = json.loads(data) - if isinstance(payload, dict) is False: - if isinstance(payload, int) is True or isinstance(payload, float) is True: - payload = str(payload) - payload = {"data": payload} - except Exception: - payload = {"data": data} - payload.update(parameters) - payload = json.dumps(payload) - - call_url = f"{self.url}/{self.id}" - r = _request_with_retry("post", call_url, headers=headers, data=payload) - logging.info(f"Model Run Async: Start service for {name} - {self.url} - {payload} - {headers}") - - resp = None - try: - if 200 <= r.status_code < 300: - resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - poll_url = resp["data"] - response = {"status": "IN_PROGRESS", "url": poll_url} - else: - if r.status_code == 401: - error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." - elif 460 <= r.status_code < 470: - error = "Subscription-related error: Please ensure that your subscription is active and has not expired." 
- elif 470 <= r.status_code < 480: - error = "Billing-related error: Please ensure you have enough credits to run this model. " - elif 480 <= r.status_code < 490: - error = "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access." - elif 490 <= r.status_code < 500: - error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." - else: - status_code = str(r.status_code) - error = ( - f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." - ) - response = {"status": "FAILED", "error_message": error} - logging.error(f"Error in request for {name} - {r.status_code}: {error}") - except Exception: - response = {"status": "FAILED"} - msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"Model Run Async: Error in running for {name}: {resp}") - if resp is not None: - response["error"] = msg + url = f"{self.url}/api/v1/execute/{self.id}" + logging.debug(f"Model Run Async: Start service for {name} - {url}") + payload = build_payload(data=data, parameters=parameters) + response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return response def check_finetune_status(self, after_epoch: Optional[int] = None): diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index c595d207..84db6704 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -21,13 +21,12 @@ Large Language Model Class """ import time -import json import logging import traceback from aixplain.enums import Function, Supplier from aixplain.modules.model import Model +from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry from typing import Union, Optional, List, Text, Dict @@ -125,31 +124,31 @@ def run( Dict: parsed output from model """ 
start = time.time() - try: - response = self.run_async( - data, - name=name, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - context=context, - prompt=prompt, - history=history, - parameters=parameters, - ) - if response["status"] == "FAILED": + parameters.update( + { + "context": parameters["context"] if "context" in parameters else context, + "prompt": parameters["prompt"] if "prompt" in parameters else prompt, + "history": parameters["history"] if "history" in parameters else history, + "temperature": parameters["temperature"] if "temperature" in parameters else temperature, + "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "top_p": parameters["top_p"] if "top_p" in parameters else top_p, + } + ) + payload = build_payload(data=data, parameters=parameters) + url = f"{self.url}/api/v2/execute/{self.id}" + logging.debug(f"Model Run Sync: Start service for {name} - {url}") + response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) + if response["status"] == "IN_PROGRESS": + try: + poll_url = response["url"] end = time.time() - response["elapsed_time"] = end - start - return response - poll_url = response["url"] - end = time.time() - response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) - return response - except Exception as e: - msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"LLM Run: Error in running for {name}: {e}") - end = time.time() - return {"status": "FAILED", "error": msg, "elapsed_time": end - start} + response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + except Exception as e: + msg = f"Error in request for {name} - {traceback.format_exc()}" + logging.error(f"Model Run: Error in running for {name}: {e}") + end = time.time() + response = {"status": "FAILED", "error": msg, "elapsed_time": end - start} + return response def run_async( self, @@ -179,66 +178,18 @@ def run_async( 
Returns: dict: polling URL in response """ - headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} - - from aixplain.factories.file_factory import FileFactory - - data = FileFactory.to_link(data) - if isinstance(data, dict): - payload = data - else: - try: - payload = json.loads(data) - if isinstance(payload, dict) is False: - if isinstance(payload, int) is True or isinstance(payload, float) is True: - payload = str(payload) - payload = {"data": payload} - except Exception: - payload = {"data": data} + url = f"{self.url}/api/v1/execute/{self.id}" + logging.debug(f"Model Run Async: Start service for {name} - {url}") parameters.update( { - "context": payload["context"] if "context" in payload else context, - "prompt": payload["prompt"] if "prompt" in payload else prompt, - "history": payload["history"] if "history" in payload else history, - "temperature": payload["temperature"] if "temperature" in payload else temperature, - "max_tokens": payload["max_tokens"] if "max_tokens" in payload else max_tokens, - "top_p": payload["top_p"] if "top_p" in payload else top_p, + "context": parameters["context"] if "context" in parameters else context, + "prompt": parameters["prompt"] if "prompt" in parameters else prompt, + "history": parameters["history"] if "history" in parameters else history, + "temperature": parameters["temperature"] if "temperature" in parameters else temperature, + "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "top_p": parameters["top_p"] if "top_p" in parameters else top_p, } ) - payload.update(parameters) - payload = json.dumps(payload) - - call_url = f"{self.url}/{self.id}" - r = _request_with_retry("post", call_url, headers=headers, data=payload) - logging.info(f"Model Run Async: Start service for {name} - {self.url} - {payload} - {headers}") - - resp = None - try: - if 200 <= r.status_code < 300: - resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - 
def build_payload(data: Union[Text, Dict], parameters: Union[Dict, None] = None):
    """Serialize the model input and run parameters into a JSON request body.

    Args:
        data: Model input. It is first passed through ``FileFactory.to_link``.
            A dict is used as the payload as-is; anything else is parsed as
            JSON and, unless it parses to a dict, wrapped as ``{"data": ...}``.
        parameters: Extra key/value pairs merged on top of the payload.
            ``None`` (the default) means no extra parameters.

    Returns:
        The payload serialized as a JSON string.
    """
    from aixplain.factories import FileFactory

    data = FileFactory.to_link(data)
    if isinstance(data, dict):
        # Copy first: the update() below would otherwise write the run
        # parameters into the dict object handed in by the caller.
        payload = dict(data)
    else:
        try:
            payload = json.loads(data)
        except Exception:
            # Not JSON (or not a string at all): send the raw value.
            payload = {"data": data}
        else:
            if not isinstance(payload, dict):
                # Bare JSON numbers are sent as text.
                if isinstance(payload, (int, float)):
                    payload = str(payload)
                payload = {"data": payload}
    if parameters:
        payload.update(parameters)
    return json.dumps(payload)


def _describe_http_error(status_code: int, resp) -> Text:
    """Return the user-facing error message for a non-2xx status code."""
    if status_code == 401:
        return f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}"
    if 460 <= status_code < 470:
        return f"Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {resp}"
    if 470 <= status_code < 480:
        return f"Billing-related error: Please ensure you have enough credits to run this model. Details: {resp}"
    if 480 <= status_code < 490:
        return f"Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {resp}"
    if 490 <= status_code < 500:
        return f"Validation-related error: Please ensure all required fields are provided and correctly formatted. Details: {resp}"
    return f"Status {status_code} - Unspecified error: {resp}"


def call_run_endpoint(url: Text, api_key: Text, payload: Union[Text, Dict]) -> Dict:
    """POST a payload to a model execution endpoint and normalize the reply.

    Args:
        url: Full execution URL.
        api_key: Value sent in the ``x-api-key`` header.
        payload: Request body, forwarded unchanged as ``data=`` to
            ``_request_with_retry`` (callers pass the JSON string produced
            by ``build_payload``).

    Returns:
        A dict that always contains ``status`` and ``completed``. It also has
        ``url`` (polling URL) when the service answers ``IN_PROGRESS``,
        ``data`` for any other 2xx status, or ``error_message`` on failure.
        This function does not raise for request or parsing errors; they are
        reported as a ``FAILED`` response.
    """
    headers = {"x-api-key": api_key, "Content-Type": "application/json"}
    generic_failure = {
        "status": "FAILED",
        "completed": True,
        "error_message": "Model Run: An error occurred while processing your request.",
    }

    try:
        r = _request_with_retry("post", url, headers=headers, data=payload)
    except Exception as e:
        # No response object exists, so there is no status code to inspect.
        logging.error(f"Error in request: {e}")
        return generic_failure

    try:
        resp = r.json()
    except Exception as e:
        # Keep going: a non-2xx status still maps to a specific message.
        logging.error(f"Error in request: {e}")
        resp = "unspecified error"

    if 200 <= r.status_code < 300:
        if not isinstance(resp, dict):
            # A successful status with a body that is not a JSON object
            # cannot carry "status"/"data"; report it instead of crashing.
            logging.error(f"Error in request: {r.status_code}: unexpected response body: {resp}")
            return generic_failure
        logging.info(f"Result of request: {r.status_code} - {resp}")
        status = resp.get("status", "IN_PROGRESS")
        data = resp.get("data", None)
        if status == "IN_PROGRESS":
            if data is None:
                # Nothing to poll: treat as a failure.
                return generic_failure
            return {"status": status, "url": data, "completed": True}
        return {"status": status, "data": data, "completed": True}

    error = _describe_http_error(r.status_code, resp)
    logging.error(f"Error in request: {r.status_code}: {error}")
    return {"status": "FAILED", "error_message": error, "completed": True}
# tests/functional/model/run_model_test.py
def test_run_async():
    """Start an asynchronous run, poll its URL, and check the final result."""
    model = ModelFactory.get("60ddef828d38c51c5885d491")

    started = model.run_async("Test")
    finished = model.sync_poll(started["url"])

    assert finished["status"] == "SUCCESS"
    assert "teste" in finished["data"].lower()


# tests/unit/llm_test.py
@pytest.mark.parametrize(
    "status_code,error_message",
    [
        (
            401,
            "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {'error': 'An unspecified error occurred while processing your request.'}",
        ),
        (
            465,
            "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {'error': 'An unspecified error occurred while processing your request.'}",
        ),
        (
            475,
            "Billing-related error: Please ensure you have enough credits to run this model. Details: {'error': 'An unspecified error occurred while processing your request.'}",
        ),
        (
            485,
            "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {'error': 'An unspecified error occurred while processing your request.'}",
        ),
        (
            495,
            "Validation-related error: Please ensure all required fields are provided and correctly formatted. Details: {'error': 'An unspecified error occurred while processing your request.'}",
        ),
        (501, "Status 501 - Unspecified error: {'error': 'An unspecified error occurred while processing your request.'}"),
    ],
)
def test_run_async_errors(status_code, error_message):
    """Each mocked error status must yield FAILED with the matching message."""
    service_root = config.MODELS_RUN_URL
    llm_id = "llm-id"
    error_body = {
        "error": "An unspecified error occurred while processing your request.",
    }
    endpoint = f"{service_root}/api/v1/execute/{llm_id}"

    with requests_mock.Mocker() as mocker:
        mocker.post(endpoint, status_code=status_code, json=error_body)
        llm = LLM(id=llm_id, name="Test llm", url=service_root, function=Function.TEXT_GENERATION)
        outcome = llm.run_async(data="input_data")

    assert outcome["status"] == "FAILED"
    assert outcome["error_message"] == error_message
def test_build_payload():
    """Plain text input is wrapped under "data" and merged with the parameters."""
    data = "input_data"
    parameters = {"context": "context_data"}
    ref_payload = json.dumps({"data": data, **parameters})

    assert build_payload(data, parameters) == ref_payload


def test_call_run_endpoint_async():
    """An IN_PROGRESS reply is returned with the polling URL under "url"."""
    execute_url = f"{config.MODELS_RUN_URL}/api/v1/execute/model-id"
    # Send a serialized body, like the JSON string build_payload produces.
    payload = json.dumps({"data": "input_data"})
    ref_response = {
        "completed": True,
        "status": "IN_PROGRESS",
        "data": "https://models.aixplain.com/api/v1/data/a90c2078-edfe-403f-acba-d2d94cf71f42",
    }

    with requests_mock.Mocker() as mock:
        mock.post(execute_url, json=ref_response)
        response = call_run_endpoint(url=execute_url, api_key=config.TEAM_API_KEY, payload=payload)

    assert response["completed"] == ref_response["completed"]
    assert response["status"] == ref_response["status"]
    assert response["url"] == ref_response["data"]


def test_call_run_endpoint_sync():
    """A reply with a final status is returned with its output under "data"."""
    # The synchronous run path posts to the v2 execute route.
    execute_url = f"{config.MODELS_RUN_URL}/api/v2/execute/model-id"
    payload = json.dumps({"data": "input_data"})
    ref_response = {"completed": True, "status": "SUCCESS", "data": "Hello"}

    with requests_mock.Mocker() as mock:
        mock.post(execute_url, json=ref_response)
        response = call_run_endpoint(url=execute_url, api_key=config.TEAM_API_KEY, payload=payload)

    assert response["completed"] == ref_response["completed"]
    assert response["status"] == ref_response["status"]
    assert response["data"] == ref_response["data"]
"https://models.aixplain.com/api/v1/data/a90c2078-edfe-403f-acba-d2d94cf71f42" @@ -65,21 +113,39 @@ def test_failed_poll(): @pytest.mark.parametrize( "status_code,error_message", [ - (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."), - (475, "Billing-related error: Please ensure you have enough credits to run this model. "), - (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."), - (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), - (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), + ( + 401, + "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {'error': 'An unspecified error occurred while processing your request.'}", + ), + ( + 465, + "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {'error': 'An unspecified error occurred while processing your request.'}", + ), + ( + 475, + "Billing-related error: Please ensure you have enough credits to run this model. Details: {'error': 'An unspecified error occurred while processing your request.'}", + ), + ( + 485, + "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {'error': 'An unspecified error occurred while processing your request.'}", + ), + ( + 495, + "Validation-related error: Please ensure all required fields are provided and correctly formatted. 
Details: {'error': 'An unspecified error occurred while processing your request.'}", + ), + (501, "Status 501 - Unspecified error: {'error': 'An unspecified error occurred while processing your request.'}"), ], ) def test_run_async_errors(status_code, error_message): base_url = config.MODELS_RUN_URL model_id = "model-id" - execute_url = urljoin(base_url, f"execute/{model_id}") + execute_url = f"{base_url}/api/v1/execute/{model_id}" + ref_response = { + "error": "An unspecified error occurred while processing your request.", + } with requests_mock.Mocker() as mock: - mock.post(execute_url, status_code=status_code) + mock.post(execute_url, status_code=status_code, json=ref_response) test_model = Model(id=model_id, name="Test Model", url=base_url) response = test_model.run_async(data="input_data") assert response["status"] == "FAILED"