Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ repos:
hooks:
- id: pytest-check
name: pytest-check
entry: coverage run -m pytest tests/unit
entry: coverage run --source=. -m pytest tests/unit
language: python
pass_filenames: false
types: [python]
Expand Down
1 change: 1 addition & 0 deletions aixplain/enums/supplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def load_suppliers():
headers = {"x-aixplain-key": aixplain_key, "Content-Type": "application/json"}
else:
headers = {"x-api-key": api_key, "Content-Type": "application/json"}
logging.debug(f"Start service for GET API Creation - {url} - {headers}")
r = _request_with_retry("get", url, headers=headers)
if not 200 <= r.status_code < 300:
raise Exception(
Expand Down
9 changes: 5 additions & 4 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def to_dict(self) -> Dict:
return {
"id": self.id,
"name": self.name,
"description": self.description,
"supplier": self.supplier,
"additional_info": clean_additional_info,
"input_params": self.input_params,
Expand Down Expand Up @@ -211,7 +212,7 @@ def run(
data: Union[Text, Dict],
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> ModelResponse:
"""Runs a model call.
Expand All @@ -220,7 +221,7 @@ def run(
data (Union[Text, Dict]): link to the input data
name (Text, optional): ID given to a call. Defaults to "model_process".
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

Returns:
Expand Down Expand Up @@ -254,14 +255,14 @@ def run(
)

def run_async(
self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}
self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None
) -> ModelResponse:
"""Runs asynchronously a model call.

Args:
data (Union[Text, Dict]): link to the input data
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.

Returns:
dict: polling URL in response
Expand Down
36 changes: 20 additions & 16 deletions aixplain/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def run(
top_p: float = 1.0,
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> ModelResponse:
"""Synchronously running a Large Language Model (LLM) model.
Expand All @@ -119,21 +119,23 @@ def run(
top_p (float, optional): Top P. Defaults to 1.0.
name (Text, optional): ID given to a call. Defaults to "model_process".
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

Returns:
Dict: parsed output from model
"""
start = time.time()
if parameters is None:
parameters = {}
parameters.update(
{
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
"context": parameters.get("context", context),
"prompt": parameters.get("prompt", prompt),
"history": parameters.get("history", history),
"temperature": parameters.get("temperature", temperature),
"max_tokens": parameters.get("max_tokens", max_tokens),
"top_p": parameters.get("top_p", top_p),
}
)
payload = build_payload(data=data, parameters=parameters)
Expand Down Expand Up @@ -173,7 +175,7 @@ def run_async(
max_tokens: int = 128,
top_p: float = 1.0,
name: Text = "model_process",
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
) -> ModelResponse:
"""Runs asynchronously a model call.

Expand All @@ -186,21 +188,23 @@ def run_async(
max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
top_p (float, optional): Top P. Defaults to 1.0.
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.

Returns:
dict: polling URL in response
"""
url = f"{self.url}/{self.id}"
logging.debug(f"Model Run Async: Start service for {name} - {url}")
if parameters is None:
parameters = {}
parameters.update(
{
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
"context": parameters.get("context", context),
"prompt": parameters.get("prompt", prompt),
"history": parameters.get("history", history),
"temperature": parameters.get("temperature", temperature),
"max_tokens": parameters.get("max_tokens", max_tokens),
"top_p": parameters.get("top_p", top_p),
}
)
payload = build_payload(data=data, parameters=parameters)
Expand Down
7 changes: 5 additions & 2 deletions aixplain/modules/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import json
import logging
from aixplain.utils.file_utils import _request_with_retry
from typing import Dict, Text, Union
from typing import Dict, Text, Union, Optional


def build_payload(data: Union[Text, Dict], parameters: Dict = {}):
def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None):
from aixplain.factories import FileFactory

if parameters is None:
parameters = {}

data = FileFactory.to_link(data)
if isinstance(data, dict):
payload = data
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from dotenv import load_dotenv

# Load environment variables once for all tests
load_dotenv()
17 changes: 8 additions & 9 deletions tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def read_data(data_path):
def run_input_map(request):
return request.param


@pytest.fixture(scope="function")
def delete_agents_and_team_agents():
for team_agent in TeamAgentFactory.list()["results"]:
Expand Down Expand Up @@ -100,12 +101,8 @@ def test_list_agents():
assert type(agents_result) is list


def test_update_draft_agent(run_input_map):
for team in TeamAgentFactory.list()["results"]:
team.delete()

for agent in AgentFactory.list()["results"]:
agent.delete()
def test_update_draft_agent(run_input_map, delete_agents_and_team_agents):
assert delete_agents_and_team_agents

tools = []
if "model_tools" in run_input_map:
Expand Down Expand Up @@ -137,7 +134,8 @@ def test_update_draft_agent(run_input_map):
agent.delete()


def test_fail_non_existent_llm():
def test_fail_non_existent_llm(delete_agents_and_team_agents):
assert delete_agents_and_team_agents
with pytest.raises(Exception) as exc_info:
AgentFactory.create(
name="Test Agent",
Expand All @@ -147,6 +145,7 @@ def test_fail_non_existent_llm():
)
assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found."


def test_delete_agent_in_use(delete_agents_and_team_agents):
assert delete_agents_and_team_agents
agent = AgentFactory.create(
Expand All @@ -160,7 +159,7 @@ def test_delete_agent_in_use(delete_agents_and_team_agents):
description="Test description",
use_mentalist_and_inspector=True,
)

with pytest.raises(Exception) as exc_info:
agent.delete()
assert str(exc_info.value) == "Agent Deletion Error (HTTP 403): err.agent_is_in_use."
assert str(exc_info.value) == "Agent Deletion Error (HTTP 403): err.agent_is_in_use."
67 changes: 67 additions & 0 deletions tests/unit/llm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,70 @@ def test_run_sync():
assert response.used_credits == 0
assert response.run_time == 0
assert response.usage is None


@pytest.mark.skip(reason="Need to fix model response")
def test_run_sync_polling_error():
"""Test handling of polling errors in the run method"""
model_id = "test-model-id"
base_url = config.MODELS_RUN_URL
execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute")

ref_response = {
"status": "IN_PROGRESS",
"data": "https://models.aixplain.com/api/v1/data/invalid-id",
}

with requests_mock.Mocker() as mock:
# Mock the initial execution call
mock.post(execute_url, json=ref_response)

# Mock the polling URL to raise an exception
poll_url = ref_response["data"]
mock.get(poll_url, exc=Exception("Polling failed"))

test_model = LLM(id=model_id, name="Test Model", function=Function.TEXT_GENERATION, url=base_url)

response = test_model.run(data="test input")

# Updated assertions to match ModelResponse structure
assert isinstance(response, ModelResponse)
assert response.status == ResponseStatus.FAILED
assert response.completed is False
assert "No response from the service" in response.error_message
assert response.data == ""
assert response.used_credits == 0
assert response.run_time == 0
assert response.usage is None


def test_run_with_custom_parameters():
"""Test run method with custom parameters"""
model_id = "test-model-id"
base_url = config.MODELS_RUN_URL
execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute")

ref_response = {
"completed": True,
"status": "SUCCESS",
"data": "Test Result",
"usedCredits": 10,
"runTime": 1.5,
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
}

with requests_mock.Mocker() as mock:
mock.post(execute_url, json=ref_response)

test_model = LLM(id=model_id, name="Test Model", function=Function.TEXT_GENERATION, url=base_url)

custom_params = {"custom_param": "value", "temperature": 0.8} # This should override the default

response = test_model.run(data="test input", temperature=0.5, parameters=custom_params)

assert isinstance(response, ModelResponse)
assert response.status == ResponseStatus.SUCCESS
assert response.data == "Test Result"
assert response.used_credits == 10
assert response.run_time == 1.5
assert response.usage == {"prompt_tokens": 10, "completion_tokens": 20}
Loading