diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 54f746d8..d6857468 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -19,12 +19,13 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent: if tool["type"] == "model": supplier = "aixplain" for supplier_ in Supplier: - if tool["supplier"] is not None and tool["supplier"].lower() in [ - supplier_.value["code"].lower(), - supplier_.value["name"].lower(), - ]: - supplier = supplier_ - break + if isinstance(tool["supplier"], str): + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier_.value["code"].lower(), + supplier_.value["name"].lower(), + ]: + supplier = supplier_ + break tool = ModelTool( function=Function(tool.get("function", None)), diff --git a/aixplain/factories/pipeline_factory/utils.py b/aixplain/factories/pipeline_factory/utils.py index 08954571..2a7de16b 100644 --- a/aixplain/factories/pipeline_factory/utils.py +++ b/aixplain/factories/pipeline_factory/utils.py @@ -14,6 +14,8 @@ Route, Script, Link, + BareSegmentor, + BareReconstructor, ) from typing import Dict @@ -32,7 +34,9 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe response["api_key"] = config.TEAM_API_KEY # instantiating pipeline generic info - pipeline = Pipeline(response["id"], response["name"], response["api_key"]) + pipeline = Pipeline( + id=response["id"], name=response["name"], api_key=response["api_key"], status=response.get("status", "draft") + ) if load_architecture is True: try: # instantiating nodes @@ -45,28 +49,45 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe elif node_json["type"].lower() == "asset": if node_json["functionType"] == "metric": node = BareMetric(asset_id=node_json["assetId"]) + elif node_json["functionType"] == "reconstructor": + node = BareReconstructor(asset_id=node_json["assetId"]) + elif node_json["functionType"] == "segmentor": + node = BareSegmentor(asset_id=node_json["assetId"]) else: node = BareAsset(asset_id=node_json["assetId"]) - elif node_json["type"].lower() == "segmentor": - raise NotImplementedError() - elif node_json["type"].lower() == "reconstructor": - raise NotImplementedError() elif node_json["type"].lower() == "decision": - node = Decision(routes=[Route(**route) for route in node_json["routes"]]) + node = Decision( + routes=[Route(**route) for route in node_json["routes"]] + ) elif node_json["type"].lower() == "router": - node = Router(routes=[Route(**route) for route in node_json["routes"]]) + node = Router( + routes=[Route(**route) for route in node_json["routes"]] + ) elif node_json["type"].lower() == "script": - node = Script(fileId=node_json["fileId"], fileMetadata=node_json["fileMetadata"]) + node = Script( + fileId=node_json["fileId"], + fileMetadata=node_json["fileMetadata"], + ) elif node_json["type"].lower() == "output": node = Output() if "inputValues" in node_json: [ node.inputs.create_param( - data_type=DataType(input_param["dataType"]) if "dataType" in input_param else None, + data_type=( + DataType(input_param["dataType"]) + if "dataType" in input_param + else None + ), code=input_param["code"], - value=input_param["value"] if "value" in input_param else None, - is_required=input_param["isRequired"] if "isRequired" in input_param else False, + value=( + input_param["value"] if "value" in input_param else None + ), + is_required=( + input_param["isRequired"] + if "isRequired" in input_param + else False + ), ) for input_param in node_json["inputValues"] if input_param["code"] not in node.inputs @@ -74,10 +95,22 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe if "outputValues" in node_json: [ node.outputs.create_param( - data_type=DataType(output_param["dataType"]) if "dataType" in output_param else None, + data_type=( + DataType(output_param["dataType"]) + if "dataType" in output_param + else None + ), code=output_param["code"], - value=output_param["value"] if "value" in output_param else None, - is_required=output_param["isRequired"] if "isRequired" in output_param else False, + value=( + output_param["value"] + if "value" in output_param + else None + ), + is_required=( + output_param["isRequired"] + if "isRequired" in output_param + else False + ), ) for output_param in node_json["outputValues"] if output_param["code"] not in node.outputs diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 581c7e88..5ff9ff69 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -92,6 +92,8 @@ def __init__( super().__init__(id, name, description, api_key, supplier, version, cost=cost) self.additional_info = additional_info self.tools = tools + for i, _ in enumerate(tools): + self.tools[i].api_key = api_key self.llm_id = llm_id if isinstance(status, str): try: @@ -110,7 +112,7 @@ def validate(self) -> None: ), "Agent Creation Error: Agent name must not contain special characters." try: - llm = ModelFactory.get(self.llm_id) + llm = ModelFactory.get(self.llm_id, api_key=self.api_key) assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") @@ -307,19 +309,19 @@ def delete(self) -> None: message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent." logging.error(message) raise Exception(f"{message}") - + def update(self) -> None: """Update agent.""" import warnings import inspect + # Get the current call stack stack = inspect.stack() - if len(stack) > 2 and stack[1].function != 'save': + if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) from aixplain.factories.agent_factory.utils import build_agent @@ -343,10 +345,9 @@ def update(self) -> None: error_msg = f"Agent Update Error (HTTP {r.status_code}): {resp}" raise Exception(error_msg) - def save(self) -> None: """Save the Agent.""" - self.update() + self.update() def deploy(self) -> None: assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." diff --git a/aixplain/modules/agent/tool/__init__.py b/aixplain/modules/agent/tool/__init__.py index 01b44dfa..aefa093a 100644 --- a/aixplain/modules/agent/tool/__init__.py +++ b/aixplain/modules/agent/tool/__init__.py @@ -22,6 +22,7 @@ """ from abc import ABC from typing import Optional, Text +from aixplain.utils import config class Tool(ABC): @@ -38,6 +39,7 @@ def __init__( name: Text, description: Text, version: Optional[Text] = None, + api_key: Optional[Text] = config.TEAM_API_KEY, **additional_info, ) -> None: """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. @@ -46,10 +48,12 @@ def __init__( name (Text): name of the tool description (Text): descriptiion of the tool version (Text): version of the tool + api_key (Text): api key of the tool. Defaults to config.TEAM_API_KEY. """ self.name = name self.description = description self.version = version + self.api_key = api_key self.additional_info = additional_info def to_dict(self): diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 0b1c3179..bdbe0f5f 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -108,7 +108,7 @@ def validate(self) -> Model: try: model = None if self.model is not None: - model = ModelFactory.get(self.model) + model = ModelFactory.get(self.model, api_key=self.api_key) return model except Exception: raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 9ea7a5fb..ab3b4311 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -62,6 +62,6 @@ def validate(self): from aixplain.factories.pipeline_factory import PipelineFactory try: - PipelineFactory.get(self.pipeline) + PipelineFactory.get(self.pipeline, api_key=self.api_key) except Exception: raise Exception(f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it.") diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index 67b3f8f7..fae597b8 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -3,7 +3,7 @@ from aixplain.utils import config from aixplain.modules.model.response import ModelResponse from typing import Text, Optional, Union, Dict -from aixplain.modules.model.document_index import DocumentIndex +from aixplain.modules.model.record import Record from typing import List @@ -51,30 +51,18 @@ def __init__( self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL - def search(self, query: str, top_k: int = 10) -> ModelResponse: - data = {"action": "search", "data": query, "payload": {"filters": {}, "top_k": top_k}} + def search(self, query: str, top_k: int = 10, filters: Dict = {}) -> ModelResponse: + data = {"action": "search", "data": query, "payload": {"filters": filters, "top_k": top_k}} return self.run(data=data) - def add(self, documents: List[DocumentIndex]) -> ModelResponse: + def upsert(self, documents: List[Record]) -> ModelResponse: payloads = [doc.to_dict() for doc in documents] data = {"action": "ingest", "data": "", "payload": {"payloads": payloads}} response = self.run(data=data) if response.status == ResponseStatus.SUCCESS: response.data = payloads return response - raise Exception(f"Failed to add documents: {response.error_message}") - - def update(self, documents: List[DocumentIndex]) -> ModelResponse: - payloads = [ - {"value": doc.value, "value_type": doc.value_type, "id": str(doc.id), "uri": doc.uri, "attributes": doc.attributes} - for doc in documents - ] - data = {"action": "update", "data": "", "payload": {"payloads": payloads}} - response = self.run(data=data) - if response.status == ResponseStatus.SUCCESS: - response.data = payloads - return response - raise Exception(f"Failed to update documents: {response.error_message}") + raise Exception(f"Failed to upsert documents: {response.error_message}") def count(self) -> float: data = {"action": "count", "data": ""} diff --git a/aixplain/modules/model/document_index.py b/aixplain/modules/model/record.py similarity index 96% rename from aixplain/modules/model/document_index.py rename to aixplain/modules/model/record.py index 12562931..a3c57173 100644 --- a/aixplain/modules/model/document_index.py +++ b/aixplain/modules/model/record.py @@ -2,7 +2,7 @@ from uuid import uuid4 -class DocumentIndex: +class Record: def __init__(self, value: str, value_type: str = "text", id: Optional[str] = None, uri: str = "", attributes: dict = {}): self.value = value self.value_type = value_type diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index 1576c1f4..9cbbe4d8 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -22,6 +22,10 @@ def __init__( self.data = data self.details = details self.completed = completed + if error_message == "": + error_message = kwargs.get("error", "") + if "supplierError" in kwargs: + error_message = f"{error_message} - {kwargs.get('supplierError', '')}" self.error_message = error_message self.used_credits = used_credits self.run_time = run_time diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index b5748ca7..474d31fa 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -115,18 +115,39 @@ def __init__( self.output_examples = output_examples def validate(self): - self.code, inputs, description = parse_code(self.code) + """Validate the Utility Model.""" + description = None + inputs = [] + # check if the model exists and if the code is strring with s3:// + # if not, parse the code and update the description and inputs and do the validation + # if yes, just do the validation on the description and inputs + if not (self._model_exists() and str(self.code).startswith("s3://")): + self.code, inputs, description = parse_code(self.code) + if self.description is None: + self.description = description + if len(self.inputs) == 0: + self.inputs = inputs + for input in self.inputs: + input.validate() + else: + logging.info("Utility Model Already Exists, skipping code validation") + assert description is not None or self.description is not None, "Utility Model Error: Model description is required" - if self.description is None: - self.description = description - if len(self.inputs) == 0: - self.inputs = inputs - for input in self.inputs: - input.validate() assert self.name and self.name.strip() != "", "Name is required" assert self.description and self.description.strip() != "", "Description is required" assert self.code and self.code.strip() != "", "Code is required" + def _model_exists(self): + if self.id is None or self.id == "": + return False + url = urljoin(self.backend_url, f"sdk/models/{self.id}") + headers = {"Authorization": f"Token {self.api_key}", "Content-Type": "application/json"} + logging.info(f"Start service for GET Model - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + if r.status_code != 200: + raise Exception() + return True + def to_dict(self): return { "name": self.name, diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index f3691928..f2cf6209 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -102,7 +102,6 @@ def parse_code(code: Union[Text, Callable]) -> Tuple[Text, List, Text]: str_code = requests.get(code).text else: str_code = code - # assert str_code has a main function if "def main(" not in str_code: raise Exception("Utility Model Error: Code must have a main function") diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 10ee3bf0..88364873 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -25,6 +25,7 @@ import json import os import logging +from aixplain.enums.asset_status import AssetStatus from aixplain.modules.asset import Asset from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry @@ -56,6 +57,7 @@ def __init__( url: Text = config.BACKEND_URL, supplier: Text = "aiXplain", version: Text = "1.0", + status: AssetStatus = AssetStatus.DRAFT, **additional_info, ) -> None: """Create a Pipeline with the necessary information @@ -67,6 +69,7 @@ def __init__( url (Text, optional): running URL of platform. Defaults to config.BACKEND_URL. supplier (Text, optional): Pipeline supplier. Defaults to "aiXplain". version (Text, optional): version of the pipeline. Defaults to "1.0". + status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ if not name: @@ -75,6 +78,12 @@ def __init__( super().__init__(id, name, "", supplier, version) self.api_key = api_key self.url = f"{url}/assets/pipeline/execution/run" + if isinstance(status, str): + try: + status = AssetStatus(status) + except Exception: + status = AssetStatus.DRAFT + self.status = status self.additional_info = additional_info def __polling( @@ -224,7 +233,6 @@ def run( data_asset=data_asset, name=name, batch_mode=batch_mode, - version=version, **kwargs, ) @@ -235,20 +243,7 @@ def run( poll_url = response["url"] end = time.time() - response = self.__polling( - poll_url, name=name, timeout=timeout, wait_time=wait_time - ) - - if self._should_fallback_to_v2(response, version): - return self.run( - data, - data_asset=data_asset, - name=name, - batch_mode=batch_mode, - version=self.VERSION_2_0, - **kwargs, - ) - response["version"] = version + response = self.__polling(poll_url, name=name, timeout=timeout, wait_time=wait_time) return response except Exception as e: error_message = f"Error in request for {name}: {str(e)}" @@ -441,16 +436,6 @@ def run_async( if resp is not None: response["error"] = resp - if self._should_fallback_to_v2(response, version): - return self.run_async( - data, - data_asset=data_asset, - name=name, - batch_mode=batch_mode, - version=self.VERSION_2_0, - **kwargs, - ) - response["version"] = version return response def update( @@ -477,8 +462,7 @@ def update( stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -566,9 +550,7 @@ def save( ), "Pipeline Update Error: Make sure the pipeline to be saved is in a JSON file." with open(pipeline) as f: pipeline = json.load(f) - self.update( - pipeline=pipeline, save_as_asset=save_as_asset, api_key=api_key - ) + self.update(pipeline=pipeline, save_as_asset=save_as_asset, api_key=api_key) for i, node in enumerate(pipeline["nodes"]): if "functionType" in node: @@ -591,12 +573,19 @@ def save( "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - logging.info( - f"Start service for Save Pipeline - {url} - {headers} - {json.dumps(payload)}" - ) - r = _request_with_retry(method, url, headers=headers, json=payload) + logging.info(f"Start service for Save Pipeline - {url} - {headers} - {json.dumps(payload)}") + r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() self.id = response["id"] logging.info(f"Pipeline {response['id']} Saved.") except Exception as e: raise Exception(e) + + def deploy(self, api_key: Optional[Text] = None) -> None: + """Deploy the Pipeline.""" + assert self.status == "draft", "Pipeline Deployment Error: Pipeline must be in draft status." + assert self.status != "onboarded", "Pipeline Deployment Error: Pipeline must be onboarded." + + pipeline = self.to_dict() + self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) + self.status = AssetStatus.ONBOARDED diff --git a/aixplain/modules/pipeline/designer/__init__.py b/aixplain/modules/pipeline/designer/__init__.py index 6a493aa4..7d880167 100644 --- a/aixplain/modules/pipeline/designer/__init__.py +++ b/aixplain/modules/pipeline/designer/__init__.py @@ -11,6 +11,8 @@ BaseMetric, BareAsset, BareMetric, + BareSegmentor, + BareReconstructor, ) from .pipeline import DesignerPipeline from .base import ( diff --git a/aixplain/modules/pipeline/designer/nodes.py b/aixplain/modules/pipeline/designer/nodes.py index 7e6e1803..fbe27991 100644 --- a/aixplain/modules/pipeline/designer/nodes.py +++ b/aixplain/modules/pipeline/designer/nodes.py @@ -474,19 +474,11 @@ class BaseReconstructor(AssetNode[TI, TO]): class ReconstructorInputs(Inputs): - data: InputParam = None - - def __init__(self, node: Node): - super().__init__(node) - self.data = self.create_param("data") + pass class ReconstructorOutputs(Outputs): - data: OutputParam = None - - def __init__(self, node: Node): - super().__init__(node) - self.data = self.create_param("data") + pass class BareReconstructor(BaseReconstructor[ReconstructorInputs, ReconstructorOutputs]): diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 80729d80..b7094348 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -70,7 +70,7 @@ def __init__( version: Optional[Text] = None, cost: Optional[Dict] = None, use_mentalist_and_inspector: bool = True, - status: AssetStatus = AssetStatus.ONBOARDING, + status: AssetStatus = AssetStatus.DRAFT, **additional_info, ) -> None: """Create a FineTune with the necessary information. @@ -97,7 +97,7 @@ def __init__( try: status = AssetStatus(status) except Exception: - status = AssetStatus.ONBOARDING + status = AssetStatus.DRAFT self.status = status def run( @@ -286,8 +286,9 @@ def to_dict(self) -> Dict: "llmId": self.llm_id, "supervisorId": self.llm_id, "plannerId": self.llm_id if self.use_mentalist_and_inspector else None, - "supplier": self.supplier, + "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, "version": self.version, + "status": self.status.value, } def validate(self) -> None: diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index 0e617397..d39ca2b9 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -153,7 +153,7 @@ def upload_data( raise Exception("File Uploading Error: Failure on Uploading to S3.") -def s3_to_csv(s3_url: Text, aws_credentials: Dict) -> Text: +def s3_to_csv(s3_url: Text, aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None}) -> Text: """Convert s3 url to a csv file and download the file in `download_path` Args: diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 214f31b9..314a56b2 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -121,12 +121,10 @@ def test_python_interpreter_tool(delete_agents_and_team_agents): def test_custom_code_tool(delete_agents_and_team_agents): assert delete_agents_and_team_agents tool = AgentFactory.create_custom_python_code_tool( - name="Add Numbers", description="Add two numbers", - code='def main(aaa: int, bbb: int) > int:\n """Add two numbers"""\n return aaa + bbb', + code='def main(aaa: int, bbb: int) -> int:\n """Add two numbers"""\n return aaa + bbb', ) assert tool is not None - assert tool.name == "Add Numbers" assert tool.description == "Add two numbers" assert tool.code == 'def main(aaa: int, bbb: int) -> int:\n """Add two numbers"""\n return aaa + bbb' agent = AgentFactory.create( diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 04e5da0d..47bd4f12 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -55,12 +55,19 @@ def test_run_async(): def test_index_model(): - from aixplain.modules.model.document_index import DocumentIndex + from uuid import uuid4 + from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory - index_model = IndexFactory.create("test", "test") - index_model.add([DocumentIndex(value="Hello, world!", value_type="text", uri="", attributes={})]) + index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4())) + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) response = index_model.search("Hello") assert str(response.status) == "SUCCESS" + assert "world" in response.data.lower() + assert index_model.count() == 1 + index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={})]) + response = index_model.search("aiXplain") + assert str(response.status) == "SUCCESS" + assert "aixplain" in response.data.lower() assert index_model.count() == 1 index_model.delete() diff --git a/tests/functional/pipelines/create_test.py b/tests/functional/pipelines/create_test.py index 6cf3d718..2cad384a 100644 --- a/tests/functional/pipelines/create_test.py +++ b/tests/functional/pipelines/create_test.py @@ -43,6 +43,11 @@ def test_create_pipeline_from_string(): assert isinstance(pipeline, Pipeline) assert pipeline.id != "" + assert pipeline.status.value == "draft" + + pipeline.deploy() + pipeline = PipelineFactory.get(pipeline.id) + assert pipeline.status.value == "onboarded" pipeline.delete() diff --git a/tests/functional/pipelines/fallback_test.py b/tests/functional/pipelines/fallback_test.py deleted file mode 100644 index 4650bff3..00000000 --- a/tests/functional/pipelines/fallback_test.py +++ /dev/null @@ -1,15 +0,0 @@ -from aixplain.factories import PipelineFactory - - -def test_fallback_to_v2(): - pipeline = PipelineFactory.get("6750535166d4db27e14f07b1") - response = pipeline.run( - "https://homepage.ntu.edu.tw/~karchung/miniconversations/mc1.mp3" - ) - assert response["version"] == "3.0" - assert response["status"] == "SUCCESS" - - pipeline = PipelineFactory.get("6750535166d4db27e14f07b1") - response = pipeline.run("<>") - assert response["version"] == "2.0" - assert response["status"] == "ERROR" diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index e60e453a..a402f324 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -33,6 +33,7 @@ def read_data(data_path): return json.load(open(data_path, "r")) + @pytest.fixture(scope="function") def delete_agents_and_team_agents(): for team_agent in TeamAgentFactory.list()["results"]: @@ -94,6 +95,7 @@ def test_end2end(run_input_map, delete_agents_and_team_agents): team_agent.deploy() team_agent = TeamAgentFactory.get(team_agent.id) assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED response = team_agent.run(data=run_input_map["query"]) assert response is not None @@ -161,6 +163,7 @@ def test_fail_non_existent_llm(): ) assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found." + def test_add_remove_agents_from_team_agent(run_input_map, delete_agents_and_team_agents): assert delete_agents_and_team_agents @@ -210,12 +213,12 @@ def test_add_remove_agents_from_team_agent(run_input_map, delete_agents_and_team assert new_agent.id in [agent.id for agent in team_agent.agents] assert len(team_agent.agents) == len(agents) + 1 - removed_agent = team_agent.agents.pop(0) + removed_agent = team_agent.agents.pop(0) team_agent.update() team_agent = TeamAgentFactory.get(team_agent.id) assert removed_agent.id not in [agent.id for agent in team_agent.agents] - assert len(team_agent.agents) == len(agents) + assert len(team_agent.agents) == len(agents) team_agent.delete() new_agent.delete() diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index 6c17a5b6..10997a75 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -9,7 +9,6 @@ from aixplain.modules.agent.utils import process_variables from urllib.parse import urljoin from unittest.mock import patch -import warnings from aixplain.enums.function import Function @@ -198,6 +197,8 @@ def test_to_dict(): description="Test Agent Description", llm_id="6646261c6eb563165658bbb1", tools=[AgentFactory.create_model_tool(function="text-generation")], + api_key="test_api_key", + status=AssetStatus.DRAFT, ) agent_json = agent.to_dict() @@ -207,6 +208,7 @@ def test_to_dict(): assert agent_json["llmId"] == "6646261c6eb563165658bbb1" assert agent_json["assets"][0]["function"] == "text-generation" assert agent_json["assets"][0]["type"] == "model" + assert agent_json["status"] == "draft" def test_update_success(): @@ -256,7 +258,10 @@ def test_update_success(): mock.get(url, headers=headers, json=model_ref_response) # Capture warnings - with pytest.warns(DeprecationWarning, match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead."): + with pytest.warns( + DeprecationWarning, + match="update\(\) is deprecated and will be removed in a future version. Please use save\(\) instead.", + ): agent.update() assert agent.id == ref_response["id"] @@ -265,6 +270,7 @@ def test_update_success(): assert agent.llm_id == ref_response["llmId"] assert agent.tools[0].function.value == ref_response["assets"][0]["function"] + def test_save_success(): agent = Agent( id="123", @@ -310,8 +316,9 @@ def test_save_success(): "pricing": {"currency": "USD", "value": 0.0}, } mock.get(url, headers=headers, json=model_ref_response) - + import warnings + # Capture warnings with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # Trigger all warnings @@ -328,6 +335,7 @@ def test_save_success(): assert agent.llm_id == ref_response["llmId"] assert agent.tools[0].function.value == ref_response["assets"][0]["function"] + def test_run_success(): agent = Agent("123", "Test Agent", "Sample Description") url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") @@ -369,3 +377,64 @@ def test_fail_utilities_without_model(): with pytest.raises(Exception) as exc_info: AgentFactory.create(name="Test", tools=[ModelTool(function=Function.UTILITIES)], llm_id="6646261c6eb563165658bbb1") assert str(exc_info.value) == "Agent Creation Error: Utility function must be used with an associated model." + + +def test_agent_api_key_propagation(): + """Test that the api_key is properly propagated to tools when creating an agent""" + custom_api_key = "custom_test_key" + tool = AgentFactory.create_model_tool(function="text-generation") + agent = Agent(id="123", name="Test Agent", description="Test Description", tools=[tool], api_key=custom_api_key) + + # Check that the agent has the correct api_key + assert agent.api_key == custom_api_key + # Check that the tool received the agent's api_key + assert agent.tools[0].api_key == custom_api_key + + +def test_agent_default_api_key(): + """Test that the default api_key is used when none is provided""" + tool = AgentFactory.create_model_tool(function="text-generation") + agent = Agent(id="123", name="Test Agent", description="Test Description", tools=[tool]) + + # Check that the agent has the default api_key + assert agent.api_key == config.TEAM_API_KEY + # Check that the tool has the default api_key + assert agent.tools[0].api_key == config.TEAM_API_KEY + + +def test_agent_multiple_tools_api_key(): + """Test that api_key is properly propagated to multiple tools""" + custom_api_key = "custom_test_key" + tools = [ + AgentFactory.create_model_tool(function="text-generation"), + AgentFactory.create_python_interpreter_tool(), + AgentFactory.create_custom_python_code_tool( + code="def main(query: str) -> str:\n return 'Hello'", description="Test Tool" + ), + ] + + agent = Agent(id="123", name="Test Agent", description="Test Description", tools=tools, api_key=custom_api_key) + + # Check that all tools received the agent's api_key + for tool in agent.tools: + assert tool.api_key == custom_api_key + + +def test_agent_api_key_in_requests(): + """Test that the api_key is properly used in API requests""" + custom_api_key = "custom_test_key" + agent = Agent(id="123", name="Test Agent", description="Test Description", api_key=custom_api_key) + + with requests_mock.Mocker() as mock: + url = agent.url + # The custom api_key should be used in the headers + headers = {"x-api-key": custom_api_key, "Content-Type": "application/json"} + ref_response = {"data": "test_url", "status": "IN_PROGRESS"} + mock.post(url, headers=headers, json=ref_response) + + response = agent.run_async(data={"query": "Test query"}) + + # Verify that the request was made with the correct api_key + assert mock.last_request.headers["x-api-key"] == custom_api_key + assert response["status"] == "IN_PROGRESS" + assert response["url"] == "test_url" diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index be9acc6f..dbf698cc 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -1,6 +1,6 @@ import requests_mock from aixplain.enums import Function, ResponseStatus -from aixplain.modules.model.document_index import DocumentIndex +from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse from aixplain.modules.model.index_model import IndexModel from aixplain.utils import config @@ -28,8 +28,8 @@ def test_add_success(): mock_response = {"status": "SUCCESS"} mock_documents = [ - DocumentIndex(value="Sample document content 1", value_type="text", id=0, uri="", attributes={}), - DocumentIndex(value="Sample document content 2", value_type="text", id=1, uri="", attributes={}), + Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={}), + Record(value="Sample document content 2", value_type="text", id=1, uri="", attributes={}), ] with requests_mock.Mocker() as mock: @@ -37,7 +37,7 @@ def test_add_success(): index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - response = index_model.add(mock_documents) + response = index_model.upsert(mock_documents) assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS @@ -47,8 +47,8 @@ def test_update_success(): mock_response = {"status": "SUCCESS"} mock_documents = [ - DocumentIndex(value="Updated document content 1", value_type="text", id=0, uri="", attributes={}), - DocumentIndex(value="Updated document content 2", value_type="text", id=1, uri="", attributes={}), + Record(value="Updated document content 1", value_type="text", id=0, uri="", attributes={}), + Record(value="Updated document content 2", value_type="text", id=1, uri="", attributes={}), ] with requests_mock.Mocker() as mock: @@ -57,7 +57,7 @@ def test_update_success(): index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - response = index_model.update(mock_documents) + response = index_model.upsert(mock_documents) assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index d1b0f9b2..913fe295 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -96,3 +96,18 @@ def test_get_pipeline_error_response(): PipelineFactory.get(pipeline_id=pipeline_id) assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" in str(excinfo.value) + + +def test_deploy_pipeline(): + with requests_mock.Mocker() as mock: + pipeline_id = "test-pipeline-id" + url = urljoin(config.BACKEND_URL, f"sdk/pipelines/{pipeline_id}") + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + mock.put(url, headers=headers, json={"status": "SUCCESS", "id": pipeline_id}) + + pipeline = Pipeline(id=pipeline_id, api_key=config.TEAM_API_KEY, name="Test Pipeline", url=config.BACKEND_URL) + pipeline.deploy() + + assert pipeline.id == pipeline_id + assert pipeline.status.value == "onboarded" diff --git a/tests/unit/team_agent_test.py b/tests/unit/team_agent_test.py index 56564b73..e6901cec 100644 --- a/tests/unit/team_agent_test.py +++ b/tests/unit/team_agent_test.py @@ -186,11 +186,32 @@ def test_create_team_agent(): llm_id="6646261c6eb563165658bbb1", agents=[agent], ) - assert team_agent.id is not None - assert team_agent.name == team_ref_response["name"] - assert team_agent.description == team_ref_response["description"] - assert team_agent.llm_id == team_ref_response["llmId"] - assert team_agent.use_mentalist_and_inspector is True - assert team_agent.status == AssetStatus.DRAFT - assert len(team_agent.agents) == 1 - assert team_agent.agents[0].id == team_ref_response["agents"][0]["assetId"] + assert team_agent.id is not None + assert team_agent.name == team_ref_response["name"] + assert team_agent.description == team_ref_response["description"] + assert team_agent.llm_id == team_ref_response["llmId"] + assert team_agent.use_mentalist_and_inspector is True + assert team_agent.status == AssetStatus.DRAFT + assert len(team_agent.agents) == 1 + assert team_agent.agents[0].id == team_ref_response["agents"][0]["assetId"] + + url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{team_agent.id}") + team_ref_response = { + "id": "team_agent_123", + "name": "TEST Multi agent", + "status": "onboarded", + "teamId": 645, + "description": "TEST Multi agent", + "llmId": "6646261c6eb563165658bbb1", + "assets": [], + "agents": [{"assetId": "123", "type": "AGENT", "number": 0, "label": "AGENT"}], + "links": [], + "plannerId": "6646261c6eb563165658bbb1", + "supervisorId": "6646261c6eb563165658bbb1", + "createdAt": "2024-10-28T19:30:25.344Z", + "updatedAt": "2024-10-28T19:30:25.344Z", + } + mock.put(url, headers=headers, json=team_ref_response) + + team_agent.deploy() + assert team_agent.status.value == "onboarded" diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index cb45597a..cd901ea0 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -95,9 +95,13 @@ def test_update_utility_model(): "utility_model_test", ), ): - mock.put(urljoin(config.BACKEND_URL, "sdk/utilities/123"), json={"id": "123"}) + # Mock both the model existence check and update endpoints + model_id = "123" + mock.get(urljoin(config.BACKEND_URL, f"sdk/models/{model_id}"), status_code=200) + mock.put(urljoin(config.BACKEND_URL, f"sdk/utilities/{model_id}"), json={"id": model_id}) + utility_model = UtilityModel( - id="123", + id=model_id, name="utility_model_test", description="utility_model_test", code="def main(originCode: str)", @@ -111,7 +115,7 @@ def test_update_utility_model(): utility_model.description = "updated_description" utility_model.update() - assert utility_model.id == "123" + assert utility_model.id == model_id assert utility_model.description == "updated_description" def test_save_utility_model(): @@ -126,9 +130,13 @@ def test_save_utility_model(): "utility_model_test", ), ): - mock.put(urljoin(config.BACKEND_URL, "sdk/utilities/123"), json={"id": "123"}) + # Mock both the model existence check and the update endpoint + model_id = "123" + mock.get(urljoin(config.BACKEND_URL, f"sdk/models/{model_id}"), status_code=200) + mock.put(urljoin(config.BACKEND_URL, f"sdk/utilities/{model_id}"), json={"id": model_id}) + utility_model = UtilityModel( - id="123", + id=model_id, name="utility_model_test", description="utility_model_test", code="def main(originCode: str)", @@ -137,6 +145,7 @@ def test_save_utility_model(): function=Function.UTILITIES, api_key=config.TEAM_API_KEY, ) + import warnings # it should not trigger any warning with warnings.catch_warnings(record=True) as w: @@ -146,7 +155,7 @@ def test_save_utility_model(): assert len(w) == 0 - assert utility_model.id == "123" + assert utility_model.id == model_id assert utility_model.description == "updated_description" @@ -218,3 +227,109 @@ def main(originCode): with pytest.raises(Exception) as exc_info: parse_code(code) assert str(exc_info.value) == "Utility Model Error: Unsupported input type: list" + +def test_validate_new_model(): + """Test validation for a new model""" + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="def main(originCode: str)"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="def main(originCode: str)"): + # Test with valid inputs + utility_model = UtilityModel( + id="", # Empty ID for new model + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str):\n return originCode", + output_examples="output_description", + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + utility_model.validate() # Should not raise any exception + + # Test with empty name + utility_model.name = "" + with pytest.raises(Exception) as exc_info: + utility_model.validate() + assert str(exc_info.value) == "Name is required" + + # Test with empty description + utility_model.name = "utility_model_test" + utility_model.description = "" + with pytest.raises(Exception) as exc_info: + utility_model.validate() + assert str(exc_info.value) == "Description is required" + + # Test with empty code + utility_model.description = "utility_model_test" + utility_model.code = "" + with pytest.raises(Exception) as exc_info: + utility_model.validate() + + assert str(exc_info.value) == "Utility Model Error: Code must have a main function" + +def test_validate_existing_model(): + """Test validation for an existing model with S3 code""" + with requests_mock.Mocker() as mock: + model_id = "123" + # Mock the model existence check + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + mock.get(url, status_code=200) + + utility_model = UtilityModel( + id=model_id, + name="utility_model_test", + description="utility_model_test", + code="s3://bucket/path/to/code", + output_examples="output_description", + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + utility_model.validate() # Should not raise any exception + +def test_model_exists_success(): + """Test _model_exists when model exists""" + with requests_mock.Mocker() as mock: + model_id = "123" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + mock.get(url, status_code=200) + + utility_model = UtilityModel( + id=model_id, + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str)", + output_examples="output_description", + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + assert utility_model._model_exists() is True + +def test_model_exists_failure(): + """Test _model_exists when model doesn't exist""" + with requests_mock.Mocker() as mock: + model_id = "123" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + mock.get(url, status_code=404) + + utility_model = UtilityModel( + id=model_id, + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str)", + output_examples="output_description", + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + with pytest.raises(Exception): + utility_model._model_exists() + +def test_model_exists_empty_id(): + """Test _model_exists with empty ID""" + utility_model = UtilityModel( + id="", # Empty ID + name="utility_model_test", + description="utility_model_test", + code="def main(originCode: str)", + output_examples="output_description", + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + assert utility_model._model_exists() is False