Skip to content
89 changes: 20 additions & 69 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
Model Class
"""
import time
import json
import logging
import traceback
from aixplain.enums import Supplier, Function
from aixplain.modules.asset import Asset
from aixplain.modules.model.utils import build_payload, call_run_endpoint
from aixplain.utils import config
from urllib.parse import urljoin
from aixplain.utils.file_utils import _request_with_retry
Expand Down Expand Up @@ -149,7 +149,7 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo
logging.error(f"Polling for Model: polling for {name}: {e}")
break
if response_body["completed"] is True:
logging.info(f"Polling for Model: Final status of polling for {name}: {response_body}")
logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}")
else:
response_body["status"] = "FAILED"
logging.error(
Expand Down Expand Up @@ -204,21 +204,21 @@ def run(
Dict: parsed output from model
"""
start = time.time()
try:
response = self.run_async(data, name=name, parameters=parameters)
if response["status"] == "FAILED":
payload = build_payload(data=data, parameters=parameters)
url = f"{self.url}/api/v2/execute/{self.id}"
logging.debug(f"Model Run Sync: Start service for {name} - {url}")
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
if response["status"] == "IN_PROGRESS":
try:
poll_url = response["url"]
end = time.time()
response["elapsed_time"] = end - start
return response
poll_url = response["url"]
end = time.time()
response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time)
return response
except Exception as e:
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Model Run: Error in running for {name}: {e}")
end = time.time()
return {"status": "FAILED", "error": msg, "elapsed_time": end - start}
response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time)
except Exception as e:
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Model Run: Error in running for {name}: {e}")
end = time.time()
response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
return response

def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Dict = {}) -> Dict:
"""Runs asynchronously a model call.
Expand All @@ -231,59 +231,10 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
Returns:
dict: polling URL in response
"""
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
from aixplain.factories.file_factory import FileFactory

data = FileFactory.to_link(data)
if isinstance(data, dict):
payload = data
else:
try:
payload = json.loads(data)
if isinstance(payload, dict) is False:
if isinstance(payload, int) is True or isinstance(payload, float) is True:
payload = str(payload)
payload = {"data": payload}
except Exception:
payload = {"data": data}
payload.update(parameters)
payload = json.dumps(payload)

call_url = f"{self.url}/{self.id}"
r = _request_with_retry("post", call_url, headers=headers, data=payload)
logging.info(f"Model Run Async: Start service for {name} - {self.url} - {payload} - {headers}")

resp = None
try:
if 200 <= r.status_code < 300:
resp = r.json()
logging.info(f"Result of request for {name} - {r.status_code} - {resp}")
poll_url = resp["data"]
response = {"status": "IN_PROGRESS", "url": poll_url}
else:
if r.status_code == 401:
error = "Unauthorized API key: Please verify the spelling of the API key and its current validity."
elif 460 <= r.status_code < 470:
error = "Subscription-related error: Please ensure that your subscription is active and has not expired."
elif 470 <= r.status_code < 480:
error = "Billing-related error: Please ensure you have enough credits to run this model. "
elif 480 <= r.status_code < 490:
error = "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."
elif 490 <= r.status_code < 500:
error = "Validation-related error: Please ensure all required fields are provided and correctly formatted."
else:
status_code = str(r.status_code)
error = (
f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request."
)
response = {"status": "FAILED", "error_message": error}
logging.error(f"Error in request for {name} - {r.status_code}: {error}")
except Exception:
response = {"status": "FAILED"}
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Model Run Async: Error in running for {name}: {resp}")
if resp is not None:
response["error"] = msg
url = f"{self.url}/api/v1/execute/{self.id}"
logging.debug(f"Model Run Async: Start service for {name} - {url}")
payload = build_payload(data=data, parameters=parameters)
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
return response

def check_finetune_status(self, after_epoch: Optional[int] = None):
Expand Down
119 changes: 35 additions & 84 deletions aixplain/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@
Large Language Model Class
"""
import time
import json
import logging
import traceback
from aixplain.enums import Function, Supplier
from aixplain.modules.model import Model
from aixplain.modules.model.utils import build_payload, call_run_endpoint
from aixplain.utils import config
from aixplain.utils.file_utils import _request_with_retry
from typing import Union, Optional, List, Text, Dict


Expand Down Expand Up @@ -125,31 +124,31 @@ def run(
Dict: parsed output from model
"""
start = time.time()
try:
response = self.run_async(
data,
name=name,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
context=context,
prompt=prompt,
history=history,
parameters=parameters,
)
if response["status"] == "FAILED":
parameters.update(
{
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
}
)
payload = build_payload(data=data, parameters=parameters)
url = f"{self.url}/api/v2/execute/{self.id}"
logging.debug(f"Model Run Sync: Start service for {name} - {url}")
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
if response["status"] == "IN_PROGRESS":
try:
poll_url = response["url"]
end = time.time()
response["elapsed_time"] = end - start
return response
poll_url = response["url"]
end = time.time()
response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time)
return response
except Exception as e:
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"LLM Run: Error in running for {name}: {e}")
end = time.time()
return {"status": "FAILED", "error": msg, "elapsed_time": end - start}
response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time)
except Exception as e:
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Model Run: Error in running for {name}: {e}")
end = time.time()
response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
return response

def run_async(
self,
Expand Down Expand Up @@ -179,66 +178,18 @@ def run_async(
Returns:
dict: polling URL in response
"""
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

from aixplain.factories.file_factory import FileFactory

data = FileFactory.to_link(data)
if isinstance(data, dict):
payload = data
else:
try:
payload = json.loads(data)
if isinstance(payload, dict) is False:
if isinstance(payload, int) is True or isinstance(payload, float) is True:
payload = str(payload)
payload = {"data": payload}
except Exception:
payload = {"data": data}
url = f"{self.url}/api/v1/execute/{self.id}"
logging.debug(f"Model Run Async: Start service for {name} - {url}")
parameters.update(
{
"context": payload["context"] if "context" in payload else context,
"prompt": payload["prompt"] if "prompt" in payload else prompt,
"history": payload["history"] if "history" in payload else history,
"temperature": payload["temperature"] if "temperature" in payload else temperature,
"max_tokens": payload["max_tokens"] if "max_tokens" in payload else max_tokens,
"top_p": payload["top_p"] if "top_p" in payload else top_p,
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
}
)
payload.update(parameters)
payload = json.dumps(payload)

call_url = f"{self.url}/{self.id}"
r = _request_with_retry("post", call_url, headers=headers, data=payload)
logging.info(f"Model Run Async: Start service for {name} - {self.url} - {payload} - {headers}")

resp = None
try:
if 200 <= r.status_code < 300:
resp = r.json()
logging.info(f"Result of request for {name} - {r.status_code} - {resp}")
poll_url = resp["data"]
response = {"status": "IN_PROGRESS", "url": poll_url}
else:
if r.status_code == 401:
error = "Unauthorized API key: Please verify the spelling of the API key and its current validity."
elif 460 <= r.status_code < 470:
error = "Subscription-related error: Please ensure that your subscription is active and has not expired."
elif 470 <= r.status_code < 480:
error = "Billing-related error: Please ensure you have enough credits to run this model. "
elif 480 <= r.status_code < 490:
error = "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."
elif 490 <= r.status_code < 500:
error = "Validation-related error: Please ensure all required fields are provided and correctly formatted."
else:
status_code = str(r.status_code)
error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request."
response = {"status": "FAILED", "error_message": error}
logging.error(f"Error in request for {name} - {r.status_code}: {error}")
except Exception:
response = {"status": "FAILED"}
msg = f"Error in request for {name} - {traceback.format_exc()}"
logging.error(f"Model Run Async: Error in running for {name}: {resp}")
if resp is not None:
response["error"] = msg
payload = build_payload(data=data, parameters=parameters)
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
return response
75 changes: 75 additions & 0 deletions aixplain/modules/model/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
__author__ = "thiagocastroferreira"

import json
import logging
from aixplain.utils.file_utils import _request_with_retry
from typing import Dict, Text, Union


def build_payload(data: Union[Text, Dict], parameters: Union[Dict, None] = None) -> Text:
    """Build the JSON request body for a model run.

    ``data`` is first passed through ``FileFactory.to_link``. The result is
    then normalized into a dict:

    * a dict is used as is (copied, see below);
    * anything else is parsed with ``json.loads``; a parsed dict is used
      directly, any other parsed value is wrapped as ``{"data": value}``
      (``int``/``float`` values are converted to ``str`` first);
    * if parsing fails, the unparsed value is wrapped as ``{"data": data}``.

    ``parameters`` are merged on top, overriding keys that already exist.

    Args:
        data (Union[Text, Dict]): model input.
        parameters (Dict, optional): extra top-level keys to merge into the
            payload. ``None`` (the default) or an empty dict adds nothing.

    Returns:
        Text: the payload serialized with ``json.dumps``.
    """
    # Imported inside the function, as in the original.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from aixplain.factories import FileFactory

    data = FileFactory.to_link(data)
    if isinstance(data, dict):
        # Shallow copy: merging parameters below must not write into a dict
        # that the caller still owns.
        payload = dict(data)
    else:
        try:
            payload = json.loads(data)
        except Exception:
            # Deliberate best-effort: anything that is not valid JSON is sent
            # verbatim under the "data" key.
            payload = {"data": data}
        else:
            if not isinstance(payload, dict):
                if isinstance(payload, (int, float)):
                    payload = str(payload)
                payload = {"data": payload}
    if parameters:
        payload.update(parameters)
    return json.dumps(payload)


def call_run_endpoint(url: Text, api_key: Text, payload: Union[Text, Dict]) -> Dict:
    """POST ``payload`` to a model run endpoint and normalize the outcome.

    Args:
        url (Text): endpoint to call.
        api_key (Text): value sent in the ``x-api-key`` header.
        payload (Union[Text, Dict]): request body, forwarded unchanged as the
            ``data`` argument of ``_request_with_retry``.

    Returns:
        Dict: always contains ``"status"`` and ``"completed": True``, plus

        * ``"url"`` when the service answers 2xx with status ``IN_PROGRESS``
          (or no status) and a non-null ``data`` field;
        * ``"data"`` when the service answers 2xx with any other status;
        * ``"error_message"`` with status ``FAILED`` in every other case
          (transport error, unusable 2xx body, or non-2xx HTTP status).

    This function does not raise for request or response-parsing failures;
    they are logged and reported through the returned dict.
    """
    headers = {"x-api-key": api_key, "Content-Type": "application/json"}
    generic_failure = {
        "status": "FAILED",
        "completed": True,
        "error_message": "Model Run: An error occurred while processing your request.",
    }

    try:
        r = _request_with_retry("post", url, headers=headers, data=payload)
    except Exception as e:
        # Without a response object there is no status code to inspect, so
        # report the failure right away instead of falling through.
        logging.error(f"Error in request: {e}")
        return generic_failure

    # Placeholder used in the error details when the body cannot be decoded.
    resp = "unspecified error"
    try:
        resp = r.json()
    except Exception as e:
        # A body that is not JSON is still classified by HTTP status below.
        logging.error(f"Error in request: {e}")

    if 200 <= r.status_code < 300:
        logging.info(f"Result of request: {r.status_code} - {resp}")
        if not isinstance(resp, dict):
            # Success status but no usable JSON object to read from.
            return generic_failure
        status = resp.get("status", "IN_PROGRESS")
        data = resp.get("data", None)
        if status == "IN_PROGRESS":
            if data is None:
                # In-progress answer without a value to hand back as "url".
                return generic_failure
            response = {"status": status, "url": data, "completed": True}
        else:
            response = {"status": status, "data": data, "completed": True}
    else:
        if r.status_code == 401:
            error = f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}"
        elif 460 <= r.status_code < 470:
            error = f"Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {resp}"
        elif 470 <= r.status_code < 480:
            error = f"Billing-related error: Please ensure you have enough credits to run this model. Details: {resp}"
        elif 480 <= r.status_code < 490:
            error = f"Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {resp}"
        elif 490 <= r.status_code < 500:
            error = f"Validation-related error: Please ensure all required fields are provided and correctly formatted. Details: {resp}"
        else:
            status_code = str(r.status_code)
            error = f"Status {status_code} - Unspecified error: {resp}"
        response = {"status": "FAILED", "error_message": error, "completed": True}
        logging.error(f"Error in request: {r.status_code}: {error}")
    return response
4 changes: 2 additions & 2 deletions aixplain/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
# Module-level logger for this configuration module.
logger = logging.getLogger(__name__)

# Every setting below is read from an environment variable of the same name,
# falling back to the literal default given as the second argument.
BACKEND_URL = os.getenv("BACKEND_URL", "https://platform-api.aixplain.com")
# NOTE(review): MODELS_RUN_URL is assigned twice; the second assignment is the
# one in effect, so the default is the bare host without "/api/v1/execute".
MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute")
MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com")
# API keys are read from environment variables; each defaults to an empty string.
TEAM_API_KEY = os.getenv("TEAM_API_KEY", "")
AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "")
PIPELINE_API_KEY = os.getenv("PIPELINE_API_KEY", "")
MODEL_API_KEY = os.getenv("MODEL_API_KEY", "")
# Log level name; defaults to "INFO" when the variable is unset.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
# NOTE(review): HF_TOKEN is assigned twice with identical arguments; the
# repetition has no effect on the resulting value.
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_TOKEN = os.getenv("HF_TOKEN", "")
4 changes: 4 additions & 0 deletions tests/functional/general_assets/data/asset_run_test_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
"id" : "61b097551efecf30109d32da",
"data": "This is a test sentence."
},
"model2" : {
"id" : "60ddefab8d38c51c5885ee38",
"data": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/myname.mp3"
},
"pipeline": {
"name": "SingleNodePipeline",
"data": "This is a test sentence."
Expand Down
12 changes: 12 additions & 0 deletions tests/functional/model/run_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,15 @@ def test_llm_run(llm_model):
)
assert response["status"] == "SUCCESS"
assert "thiago" in response["data"].lower()


def test_run_async():
    """Start an asynchronous model run, poll its URL and check the final output."""
    async_model = ModelFactory.get("60ddef828d38c51c5885d491")

    # run_async hands back the polling URL under "url".
    started = async_model.run_async("Test")
    finished = async_model.sync_poll(started["url"])

    assert finished["status"] == "SUCCESS"
    assert "teste" in finished["data"].lower()
Loading