Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a7ac7ad
ENG-1852: Hotfix, added fail-fast option (#488)
kadirpekel Apr 11, 2025
b0887bd
Fix: BUG-503 failed to update utility tool after get (#491)
ahmetgunduz Apr 16, 2025
89daf2f
Bug-503-fix Update (#493)
ahmetgunduz Apr 16, 2025
6bf7d00
ENG-1920: added sentry cred (#476)
xainaz Apr 16, 2025
9e2b03c
Read input variables correctly (#496)
thiago-aixplain Apr 17, 2025
409a893
Fix: file not deleted in test_sql_tool_with_csv test (#492)
ahmetgunduz Apr 17, 2025
3acfacd
ENG-2007: Fix agent and team agent parametrized functional test - Syn…
lucas-aixplain Apr 17, 2025
da7ec38
ENG-1789: Add multiple index backbones support (#443)
basitanees Apr 23, 2025
d360780
ENG-2049: rename air functions (#500)
thiago-aixplain Apr 24, 2025
24e88da
ENG 1924: aixplain sdk new test cases for agents using utility and pi…
OsujiCC Apr 24, 2025
044d8d8
add pydantic requirement (#502)
basitanees Apr 25, 2025
4efdbc8
ENG-1978: Adding instructions to teams (#485)
thiago-aixplain Apr 25, 2025
ab2fcc5
BUG-504: Merged paramMappings for the same link vectors (#499)
kadirpekel Apr 25, 2025
268cc1a
ENG-1836: Set name of tools on the SDK (#501)
thiago-aixplain Apr 28, 2025
9796d19
Eng 2051 Improvements on CI flow (#509)
kadirpekel Apr 30, 2025
f0837fc
Add a finetuned version of BGE model (#512)
Muhammad-Elmallah May 5, 2025
511bf5f
ENG-2055-Aixplain-SDK-Centralized-Error-Handling (#510)
ahmetgunduz May 6, 2025
0c4edf4
ENG-1862:Added status to Tools and deployment check for Agent and Tea…
ahmetgunduz May 6, 2025
0989c51
Merge branch 'test' into development
hadi-aix May 8, 2025
84baed8
Fix: BUG-543 SQL Tool upload db issue (#519)
ahmetgunduz May 9, 2025
de6a1c4
ENG-2028: model streaming (#506)
thiago-aixplain May 9, 2025
b909a54
BUG-542-Utility-Model-Update-Test-Failing-in-SDK (#515)
ahmetgunduz May 9, 2025
8aa3d31
ENG-2115 fixing agent tests (#522)
thiago-aixplain May 9, 2025
4f20c65
Merge branch 'test' into development
thiago-aixplain May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ def create_model_tool(
supplier = supplier_
break
assert isinstance(supplier, Supplier), f"Supplier {supplier} is not a valid supplier"
return ModelTool(function=function, supplier=supplier, model=model, description=description, parameters=parameters)
return ModelTool(
function=function, supplier=supplier, model=model, name=name, description=description, parameters=parameters
)

@classmethod
def create_pipeline_tool(
Expand Down Expand Up @@ -278,7 +280,6 @@ def create_sql_tool(
base_name = os.path.splitext(os.path.basename(source))[0]
db_path = os.path.join(os.path.dirname(source), f"{base_name}.db")
table_name = tables[0] if tables else None

try:
# Create database from CSV
schema = create_database_from_csv(source, db_path, table_name)
Expand Down
1 change: 1 addition & 0 deletions aixplain/factories/model_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def create_model_from_response(response: Dict) -> Model:
version=response["version"]["id"],
inputs=inputs,
temperature=temperature,
supports_streaming=response.get("supportsStreaming", False),
status=status,
)

Expand Down
2 changes: 1 addition & 1 deletion aixplain/modules/agent/tool/sql_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(
self.tables = tables if isinstance(tables, list) else [tables] if tables else None
self.enable_commit = enable_commit
self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool
self.validate() # to upload the database

def to_dict(self) -> Dict[str, Text]:
return {
Expand All @@ -306,7 +307,6 @@ def validate(self):
raise SQLToolError("Description is required")
if not self.database:
raise SQLToolError("Database must be provided")

# Handle database validation
if not (
str(self.database).startswith("s3://")
Expand Down
27 changes: 24 additions & 3 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import traceback
from aixplain.enums import Supplier, Function
from aixplain.modules.asset import Asset
from aixplain.modules.model.model_response_streamer import ModelResponseStreamer
from aixplain.modules.model.utils import build_payload, call_run_endpoint
from aixplain.utils import config
from urllib.parse import urljoin
Expand Down Expand Up @@ -56,6 +57,7 @@ class Model(Asset):
input_params (ModelParameters, optional): input parameters for the function.
output_params (Dict, optional): output parameters for the function.
model_params (ModelParameters, optional): parameters for the function.
supports_streaming (bool, optional): whether the model supports streaming. Defaults to False.
"""

def __init__(
Expand All @@ -73,6 +75,7 @@ def __init__(
input_params: Optional[Dict] = None,
output_params: Optional[Dict] = None,
model_params: Optional[Dict] = None,
supports_streaming: bool = False,
status: Optional[AssetStatus] = AssetStatus.ONBOARDED, # default status for models is ONBOARDED
**additional_info,
) -> None:
Expand All @@ -91,6 +94,7 @@ def __init__(
input_params (Dict, optional): input parameters for the function.
output_params (Dict, optional): output parameters for the function.
model_params (Dict, optional): parameters for the function.
supports_streaming (bool, optional): whether the model supports streaming. Defaults to False.
status (AssetStatus, optional): status of the model. Defaults to None.
**additional_info: Any additional Model info to be saved
"""
Expand All @@ -105,6 +109,7 @@ def __init__(
self.input_params = input_params
self.output_params = output_params
self.model_params = ModelParameters(model_params) if model_params else None
self.supports_streaming = supports_streaming
if isinstance(status, str):
try:
status = AssetStatus(status)
Expand Down Expand Up @@ -232,14 +237,28 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse:
completed=False,
)

def run_stream(
    self,
    data: Union[Text, Dict],
    parameters: Optional[Dict] = None,
) -> ModelResponseStreamer:
    """Run the model with streaming enabled and return a chunk iterator.

    Args:
        data (Union[Text, Dict]): input data for the model run.
        parameters (Dict, optional): optional model parameters. Defaults to None.

    Returns:
        ModelResponseStreamer: iterator over the streamed response chunks.

    Raises:
        AssertionError: if this model does not declare streaming support.
    """
    # Guard early: streaming is only valid for models flagged as supporting it.
    assert self.supports_streaming, f"Model '{self.name} ({self.id})' does not support streaming"
    # stream=True makes build_payload set options.stream in the request payload.
    payload = build_payload(data=data, parameters=parameters, stream=True)
    # Streaming goes through the v2 execute endpoint; swap the version segment in the URL.
    url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute")
    logging.debug(f"Model Run Stream: Start service for {url} - {payload}")
    headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
    # stream=True keeps the HTTP connection open so lines can be consumed lazily.
    r = _request_with_retry("post", url, headers=headers, data=payload, stream=True)
    # decode_unicode=True yields text lines (SSE frames) rather than raw bytes.
    return ModelResponseStreamer(r.iter_lines(decode_unicode=True))

def run(
self,
data: Union[Text, Dict],
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> ModelResponse:
stream: bool = False,
) -> Union[ModelResponse, ModelResponseStreamer]:
"""Runs a model call.

Args:
Expand All @@ -248,10 +267,12 @@ def run(
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

stream (bool, optional): whether the model supports streaming. Defaults to False.
Returns:
Dict: parsed output from model
Union[ModelResponse, ModelStreamer]: parsed output from model
"""
if stream:
return self.run_stream(data=data, parameters=parameters)
start = time.time()
payload = build_payload(data=data, parameters=parameters)
url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute")
Expand Down
13 changes: 9 additions & 4 deletions aixplain/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import traceback
from aixplain.enums import Function, Supplier
from aixplain.modules.model import Model
from aixplain.modules.model.model_response_streamer import ModelResponseStreamer
from aixplain.modules.model.utils import build_payload, call_run_endpoint
from aixplain.utils import config
from typing import Union, Optional, List, Text, Dict
Expand Down Expand Up @@ -108,7 +109,8 @@ def run(
timeout: float = 300,
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> ModelResponse:
stream: bool = False,
) -> Union[ModelResponse, ModelResponseStreamer]:
"""Synchronously running a Large Language Model (LLM) model.

Args:
Expand All @@ -123,9 +125,9 @@ def run(
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

stream (bool, optional): whether the model supports streaming. Defaults to False.
Returns:
Dict: parsed output from model
Union[ModelResponse, ModelStreamer]: parsed output from model
"""
start = time.time()
parameters = parameters or {}
Expand All @@ -141,10 +143,13 @@ def run(
parameters.setdefault("max_tokens", max_tokens)
parameters.setdefault("top_p", top_p)

if stream:
return self.run_stream(data=data, parameters=parameters)

payload = build_payload(data=data, parameters=parameters)
logging.info(payload)
url = f"{self.url}/{self.id}".replace("/api/v1/execute", "/api/v2/execute")
logging.debug(f"Model Run Sync: Start service for {name} - {url}")
logging.debug(f"Model Run Sync: Start service for {name} - {url} - {payload}")
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
if response["status"] == "IN_PROGRESS":
try:
Expand Down
28 changes: 28 additions & 0 deletions aixplain/modules/model/model_response_streamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import json
from typing import Iterator

from aixplain.modules.model.response import ModelResponse, ResponseStatus


class ModelResponseStreamer:
    """Iterator that converts a stream of server-sent-event lines into ModelResponse chunks.

    Each incoming line is expected to be framed as ``data: <payload>``, where
    the payload is normally JSON with a ``data`` field carrying the chunk
    content. The sentinel payload ``[DONE]`` marks the end of the stream and
    flips ``status`` to SUCCESS.
    """

    def __init__(self, iterator: Iterator):
        # iterator yields raw SSE text lines (e.g. requests' iter_lines(decode_unicode=True))
        self.iterator = iterator
        self.status = ResponseStatus.IN_PROGRESS

    def __next__(self):
        """
        Returns the next chunk of the response.
        """
        line = next(self.iterator)
        # Strip only the leading SSE "data: " prefix. The previous blanket
        # str.replace removed every occurrence of "data: ", which would
        # corrupt any payload legitimately containing that substring.
        if line.startswith("data: "):
            line = line[len("data: "):]
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            # Non-JSON lines are passed through as raw chunk content.
            data = {"data": line}
        content = data.get("data", "")
        if content == "[DONE]":
            self.status = ResponseStatus.SUCCESS
            content = ""
        return ModelResponse(status=self.status, data=content)

    def __iter__(self):
        return self
1 change: 0 additions & 1 deletion aixplain/modules/model/utility_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def update(self):
DeprecationWarning,
stacklevel=2,
)

self.validate()
url = urljoin(self.backend_url, f"sdk/utilities/{self.id}")
headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"}
Expand Down
8 changes: 7 additions & 1 deletion aixplain/modules/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
from aixplain.exceptions import get_error_from_status_code


def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None):
def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None, stream: Optional[bool] = None):
from aixplain.factories import FileFactory

if parameters is None:
parameters = {}

if stream is not None:
if "options" not in parameters:
parameters["options"] = {}
parameters["options"]["stream"] = stream

data = FileFactory.to_link(data)
if isinstance(data, dict):
payload = data
Expand All @@ -36,6 +41,7 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict:

resp = "unspecified error"
try:
logging.debug(f"Calling {url} with payload: {payload}")
r = _request_with_retry("post", url, headers=headers, data=payload)
resp = r.json()
except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions aixplain/v2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def create_custom_python_code_tool(
@classmethod
def create_sql_tool(
cls,
name: str,
description: str,
source: str,
source_type: str,
Expand Down Expand Up @@ -172,6 +173,7 @@ def create_sql_tool(
from aixplain.factories import AgentFactory

return AgentFactory.create_sql_tool(
name=name,
description=description,
source=source,
source_type=source_type,
Expand Down
31 changes: 17 additions & 14 deletions tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_custom_code_tool(delete_agents_and_team_agents, AgentFactory):
)
assert tool is not None
assert tool.description == "Add two strings"
assert tool.code == 'def main(aaa: str, bbb: str) -> str:\n """Add two strings"""\n return aaa + bbb'
assert tool.code.startswith("s3://")
agent = AgentFactory.create(
name="Add Strings Agent",
description="Add two strings. Do not directly answer. Use the tool to add the strings.",
Expand Down Expand Up @@ -354,6 +354,7 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent
@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent])
def test_sql_tool(delete_agents_and_team_agents, AgentFactory):
assert delete_agents_and_team_agents
agent = None
try:
import os

Expand All @@ -362,6 +363,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory):
f.write("")

tool = AgentFactory.create_sql_tool(
name="Teste",
description="Execute an SQL query and return the result",
source="ftest.db",
source_type="sqlite",
Expand Down Expand Up @@ -394,11 +396,13 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory):
assert "eve" in str(response["data"]["output"]).lower()
finally:
os.remove("ftest.db")
agent.delete()
if agent:
agent.delete()

@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent])
def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory):
assert delete_agents_and_team_agents
agent = None
try:
import os
import pandas as pd
Expand All @@ -424,7 +428,11 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory):

# Create SQL tool from CSV
tool = AgentFactory.create_sql_tool(
description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"]
name="CSV Tool Test",
description="Execute SQL queries on employee data",
source="test.csv",
source_type="csv",
tables=["employees"],
)

# Verify tool setup
Expand Down Expand Up @@ -470,9 +478,12 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory):

finally:
# Cleanup
os.remove("test.csv")
os.remove("test.db")
agent.delete()
if agent:
agent.delete()
if os.path.exists("test.csv"):
os.remove("test.csv")
if os.path.exists("test.db"):
os.remove("test.db")


@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent])
Expand All @@ -499,13 +510,6 @@ def test_instructions(delete_agents_and_team_agents, AgentFactory):
assert response["data"]["session_id"] is not None
assert response["data"]["output"] is not None
assert "aixplain" in response["data"]["output"].lower()
assert "eve" in response["data"]["output"].lower()

import os

# Cleanup
os.remove("test.csv")
os.remove("test.db")
agent.delete()


Expand Down Expand Up @@ -595,4 +599,3 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory):

assert "hello" in answer["data"]["output"].lower()
assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower()

21 changes: 19 additions & 2 deletions tests/functional/model/run_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ def test_llm_run(llm_model):
assert response["status"] == "SUCCESS"


def test_llm_run_stream():
    """Functional test: LLM run with stream=True returns a ModelResponseStreamer
    whose chunks are ModelResponse objects and which ends in SUCCESS.

    NOTE(review): this hits a live model by hard-coded id — requires network
    access and valid credentials.
    """
    from aixplain.modules.model.response import ModelResponse, ResponseStatus
    from aixplain.modules.model.model_response_streamer import ModelResponseStreamer

    # Hard-coded id of a streaming-capable LLM in the backend — TODO confirm it stays valid.
    llm_model = ModelFactory.get("669a63646eb56306647e1091")

    assert isinstance(llm_model, LLM)
    response = llm_model.run(
        data="This is a test prompt where I expect you to respond with the following phrase: 'This is a test response.'",
        stream=True,
    )
    # stream=True must short-circuit run() into run_stream() and return a streamer.
    assert isinstance(response, ModelResponseStreamer)
    for chunk in response:
        assert isinstance(chunk, ModelResponse)
        # Substring check: each streamed chunk should be a piece of the expected
        # phrase (an empty chunk passes trivially, e.g. the [DONE] frame).
        assert chunk.data in "This is a test response."
    # Streamer flips to SUCCESS once the [DONE] sentinel is consumed.
    assert response.status == ResponseStatus.SUCCESS


def test_run_async():
"""Testing Model Async"""
model = ModelFactory.get("60ddef828d38c51c5885d491")
Expand Down Expand Up @@ -100,7 +119,6 @@ def run_index_model(index_model):
pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"),
pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"),
pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="AIR - aiXplain Legal Embeddings"),

],
)
def test_index_model(embedding_model, supplier_params):
Expand All @@ -126,7 +144,6 @@ def test_index_model(embedding_model, supplier_params):
pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"),
pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"),
pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="aiXplain Legal Embeddings"),

],
)
def test_index_model_with_filter(embedding_model, supplier_params):
Expand Down
Loading