diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a79973ee..c6b06079 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: hooks: - id: pytest-check name: pytest-check - entry: coverage run -m pytest tests/unit + entry: coverage run --source=. -m pytest tests/unit language: python pass_filenames: false types: [python] diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index d66facce..947d59a9 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -13,4 +13,4 @@ from .supplier import Supplier from .sort_by import SortBy from .sort_order import SortOrder -from .model_status import ModelStatus +from .response_status import ResponseStatus \ No newline at end of file diff --git a/aixplain/enums/model_status.py b/aixplain/enums/model_status.py deleted file mode 100644 index af4ae0a9..00000000 --- a/aixplain/enums/model_status.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import Enum -from typing import Text - - -class ModelStatus(Text, Enum): - FAILED = "FAILED" - IN_PROGRESS = "IN_PROGRESS" - SUCCESS = "SUCCESS" - - def __str__(self): - return self._value_ diff --git a/aixplain/enums/response_status.py b/aixplain/enums/response_status.py new file mode 100644 index 00000000..d2810753 --- /dev/null +++ b/aixplain/enums/response_status.py @@ -0,0 +1,31 @@ +__author__ = "thiagocastroferreira" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: February 21st 2024 +Description: + Response Status Enum +""" + +from enum import Enum +from typing import Text + + +class ResponseStatus(Text, Enum): + IN_PROGRESS = "IN_PROGRESS" + SUCCESS = "SUCCESS" + FAILED = "FAILED" diff --git a/aixplain/enums/supplier.py b/aixplain/enums/supplier.py index 5d3e137d..ecc29998 100644 --- a/aixplain/enums/supplier.py +++ b/aixplain/enums/supplier.py @@ -48,6 +48,7 @@ def load_suppliers(): headers = {"x-aixplain-key": aixplain_key, "Content-Type": "application/json"} else: headers = {"x-api-key": api_key, "Content-Type": "application/json"} + logging.debug(f"Start service for GET API Creation - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: raise Exception( diff --git a/aixplain/factories/file_factory.py b/aixplain/factories/file_factory.py index 2085c75d..1a29ac11 100644 --- a/aixplain/factories/file_factory.py +++ b/aixplain/factories/file_factory.py @@ -91,7 +91,7 @@ def check_storage_type(cls, input_link: Any) -> StorageType: Returns: StorageType: URL, TEXT or FILE """ - if os.path.exists(input_link) is True: + if os.path.exists(input_link) is True and os.path.isfile(input_link) is True: return StorageType.FILE elif ( input_link.startswith("s3://") diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index 209ff75d..b6588023 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -222,7 +222,7 @@ def _get_assets_from_page( @classmethod def list( cls, - function: Function, + function: Optional[Function] = None, query: Optional[Text] = "", suppliers: Optional[Union[Supplier, List[Supplier]]] = None, source_languages: Optional[Union[Language, List[Language]]] = None, diff --git a/aixplain/factories/pipeline_factory/utils.py b/aixplain/factories/pipeline_factory/utils.py index c9291031..08954571 100644 --- a/aixplain/factories/pipeline_factory/utils.py +++ b/aixplain/factories/pipeline_factory/utils.py @@ -88,7 +88,7 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe data_type=custom_input.get("dataType"), code=custom_input["code"], value=custom_input.get("value"), - is_required=custom_input.get("isRequired", False), + is_required=custom_input.get("isRequired", True), ) node.number = node_json["number"] node.label = node_json["label"] diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 41abf865..104bcb62 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -32,7 +32,7 @@ from typing import Union, Optional, Text, Dict from datetime import datetime from aixplain.modules.model.response import ModelResponse -from aixplain.enums import ModelStatus +from aixplain.enums.response_status import ResponseStatus class Model(Asset): @@ -106,6 +106,7 @@ def to_dict(self) -> Dict: return { "id": self.id, "name": self.name, + "description": self.description, "supplier": self.supplier, "additional_info": clean_additional_info, "input_params": self.input_params, @@ -118,7 +119,9 @@ def __repr__(self): except Exception: return f"<Model: {self.name} by {self.supplier}>" - def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300) -> Dict: + def sync_poll( + self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300 + ) -> ModelResponse: """Keeps polling the platform to check whether an
asynchronous call is done. Args: @@ -135,7 +138,7 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo # keep wait time as 0.2 seconds the minimum wait_time = max(wait_time, 0.2) completed = False - response_body = {"status": "FAILED", "completed": False} + response_body = ModelResponse(status=ResponseStatus.FAILED, completed=False) while not completed and (end - start) < timeout: try: response_body = self.poll(poll_url, name=name) @@ -147,13 +150,17 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo if wait_time < 60: wait_time *= 1.1 except Exception as e: - response_body = {"status": "FAILED", "completed": False, "error_message": "No response from the service."} + response_body = ModelResponse( + status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." + ) logging.error(f"Polling for Model: polling for {name}: {e}") break if response_body["completed"] is True: logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}") else: - response_body["status"] = "FAILED" + response_body = ModelResponse( + status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." + ) logging.error( f"Polling for Model: Final status of polling for {name}: No response in {timeout} seconds - {response_body}" ) @@ -174,11 +181,11 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: try: resp = r.json() if resp["completed"] is True: - status = ModelStatus.SUCCESS + status = ResponseStatus.SUCCESS if "error_message" in resp or "supplierError" in resp: - status = ModelStatus.FAILED + status = ResponseStatus.FAILED else: - status = ModelStatus.IN_PROGRESS + status = ResponseStatus.IN_PROGRESS logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}") return ModelResponse( status=resp.pop("status", status), @@ -195,7 +202,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: resp = {"status": "FAILED"} logging.error(f"Single Poll for Model: Error of polling for {name}: {e}") return ModelResponse( - status=ModelStatus.FAILED, + status=ResponseStatus.FAILED, error_message=str(e), completed=False, ) @@ -205,7 +212,7 @@ def run( data: Union[Text, Dict], name: Text = "model_process", timeout: float = 300, - parameters: Optional[Dict] = {}, + parameters: Optional[Dict] = None, wait_time: float = 0.5, ) -> ModelResponse: """Runs a model call. @@ -214,7 +221,7 @@ def run( data (Union[Text, Dict]): link to the input data name (Text, optional): ID given to a call. Defaults to "model_process". timeout (float, optional): total polling time. Defaults to 300. - parameters (Dict, optional): optional parameters to the model. Defaults to "{}". + parameters (Dict, optional): optional parameters to the model. Defaults to None. wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. 
Returns: @@ -234,9 +241,9 @@ def run( msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Model Run: Error in running for {name}: {e}") end = time.time() - response = {"status": "FAILED", "error": msg, "runTime": end - start} + response = {"status": "FAILED", "error_message": msg, "runTime": end - start} return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), @@ -247,13 +254,15 @@ def run( **response, ) - def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> ModelResponse: + def run_async( + self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None + ) -> ModelResponse: """Runs asynchronously a model call. Args: data (Union[Text, Dict]): link to the input data name (Text, optional): ID given to a call. Defaults to "model_process". - parameters (Dict, optional): optional parameters to the model. Defaults to "{}". + parameters (Dict, optional): optional parameters to the model. Defaults to None. Returns: dict: polling URL in response @@ -263,7 +272,7 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param payload = build_payload(data=data, parameters=parameters) response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index 941c4a6f..600fd32e 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -29,7 +29,7 @@ from aixplain.utils import config from typing import Union, Optional, List, Text, Dict from aixplain.modules.model.response import ModelResponse -from aixplain.enums import ModelStatus +from aixplain.enums.response_status import ResponseStatus class LLM(Model): @@ -104,7 +104,7 @@ def run( top_p: float = 1.0, name: Text = "model_process", timeout: float = 300, - parameters: Optional[Dict] = {}, + parameters: Optional[Dict] = None, wait_time: float = 0.5, ) -> ModelResponse: """Synchronously running a Large Language Model (LLM) model. @@ -119,21 +119,23 @@ def run( top_p (float, optional): Top P. Defaults to 1.0. name (Text, optional): ID given to a call. Defaults to "model_process". timeout (float, optional): total polling time. Defaults to 300. - parameters (Dict, optional): optional parameters to the model. Defaults to "{}". + parameters (Dict, optional): optional parameters to the model. Defaults to None. wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. 
Returns: Dict: parsed output from model """ start = time.time() + if parameters is None: + parameters = {} parameters.update( { - "context": parameters["context"] if "context" in parameters else context, - "prompt": parameters["prompt"] if "prompt" in parameters else prompt, - "history": parameters["history"] if "history" in parameters else history, - "temperature": parameters["temperature"] if "temperature" in parameters else temperature, - "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, - "top_p": parameters["top_p"] if "top_p" in parameters else top_p, + "context": parameters.get("context", context), + "prompt": parameters.get("prompt", prompt), + "history": parameters.get("history", history), + "temperature": parameters.get("temperature", temperature), + "max_tokens": parameters.get("max_tokens", max_tokens), + "top_p": parameters.get("top_p", top_p), } ) payload = build_payload(data=data, parameters=parameters) @@ -152,7 +154,7 @@ def run( end = time.time() response = {"status": "FAILED", "error": msg, "elapsed_time": end - start} return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), @@ -173,7 +175,7 @@ def run_async( max_tokens: int = 128, top_p: float = 1.0, name: Text = "model_process", - parameters: Optional[Dict] = {}, + parameters: Optional[Dict] = None, ) -> ModelResponse: """Runs asynchronously a model call. @@ -186,27 +188,29 @@ def run_async( max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128. top_p (float, optional): Top P. Defaults to 1.0. name (Text, optional): ID given to a call. Defaults to "model_process". - parameters (Dict, optional): optional parameters to the model. Defaults to "{}". + parameters (Dict, optional): optional parameters to the model. Defaults to None. 
Returns: dict: polling URL in response """ url = f"{self.url}/{self.id}" logging.debug(f"Model Run Async: Start service for {name} - {url}") + if parameters is None: + parameters = {} parameters.update( { - "context": parameters["context"] if "context" in parameters else context, - "prompt": parameters["prompt"] if "prompt" in parameters else prompt, - "history": parameters["history"] if "history" in parameters else history, - "temperature": parameters["temperature"] if "temperature" in parameters else temperature, - "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, - "top_p": parameters["top_p"] if "top_p" in parameters else top_p, + "context": parameters.get("context", context), + "prompt": parameters.get("prompt", prompt), + "history": parameters.get("history", history), + "temperature": parameters.get("temperature", temperature), + "max_tokens": parameters.get("max_tokens", max_tokens), + "top_p": parameters.get("top_p", top_p), } ) payload = build_payload(data=data, parameters=parameters) response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return ModelResponse( - status=response.pop("status", ModelStatus.FAILED), + status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), details=response.pop("details", {}), completed=response.pop("completed", False), diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index 42ed09a4..94ddcb9d 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -1,15 +1,13 @@ -from dataclasses import dataclass from typing import Text, Any, Optional, Dict, List, Union -from aixplain.enums import ModelStatus +from aixplain.enums import ResponseStatus -@dataclass class ModelResponse: """ModelResponse class to store the response of the model run.""" def __init__( self, - status: ModelStatus, + status: ResponseStatus, data: Text = "", details: Optional[Union[Dict, List]] = {}, completed: bool = False, diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 2235b35a..13cc1f7c 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -3,12 +3,15 @@ import json import logging from aixplain.utils.file_utils import _request_with_retry -from typing import Dict, Text, Union +from typing import Dict, Text, Union, Optional -def build_payload(data: Union[Text, Dict], parameters: Dict = {}): +def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): from aixplain.factories import FileFactory + if parameters is None: + parameters = {} + data = FileFactory.to_link(data) if isinstance(data, dict): payload = data diff --git a/aixplain/modules/pipeline/designer/base.py b/aixplain/modules/pipeline/designer/base.py index a925840f..08d4c8c5 100644 --- a/aixplain/modules/pipeline/designer/base.py +++ b/aixplain/modules/pipeline/designer/base.py @@ -1,3 +1,4 @@ +import re from typing import ( List, Union, @@ -11,7 +12,7 @@ from aixplain.enums import DataType from .enums import NodeType, ParamType - +from .utils import find_prompt_params if TYPE_CHECKING: from .pipeline import DesignerPipeline @@ -280,14 +281,31 @@ def __getitem__(self, code: str) -> Param: return param raise KeyError(f"Parameter with code '{code}' not found.") + def special_prompt_handling(self, code: str, value: str) -> None: + """ + This method will handle the special prompt handling for asset nodes + having `text-generation` function type. 
+ """ + from .nodes import AssetNode + + if isinstance(self.node, AssetNode) and self.node.asset.function == "text-generation": + if code == "prompt": + matches = find_prompt_params(value) + for match in matches: + self.node.inputs.create_param(match, DataType.TEXT, is_required=True) + + def set_param_value(self, code: str, value: str) -> None: + self.special_prompt_handling(code, value) + self[code].value = value + def __setitem__(self, code: str, value: str) -> None: # set param value on set item to avoid setting it manually - self[code].value = value + self.set_param_value(code, value) def __setattr__(self, name: str, value: any) -> None: # set param value on attribute assignment to avoid setting it manually if isinstance(value, str) and hasattr(self, name): - self[name].value = value + self.set_param_value(name, value) else: super().__setattr__(name, value) diff --git a/aixplain/modules/pipeline/designer/pipeline.py b/aixplain/modules/pipeline/designer/pipeline.py index ece5ac0c..79013590 100644 --- a/aixplain/modules/pipeline/designer/pipeline.py +++ b/aixplain/modules/pipeline/designer/pipeline.py @@ -6,7 +6,7 @@ from .nodes import AssetNode, Decision, Script, Input, Output, Router, Route, BareReconstructor, BareSegmentor, BareMetric from .enums import NodeType, RouteType, Operation from .mixins import OutputableMixin - +from .utils import find_prompt_params T = TypeVar("T", bound="AssetNode") @@ -125,6 +125,24 @@ def is_param_set(self, node, param): """ return param.value or self.is_param_linked(node, param) + def special_prompt_validation(self, node: Node): + """ + This method will handle the special rule for asset nodes having + `text-generation` function type where if any prompt variable exists + then the `text` param is not required but the prompt param are. + + :param node: the node + :raises ValueError: if the pipeline is not valid + """ + if isinstance(node, AssetNode) and node.asset.function == "text-generation": + if self.is_param_set(node, node.inputs.prompt): + matches = find_prompt_params(node.inputs.prompt.value) + if matches: + node.inputs.text.is_required = False + for match in matches: + if match not in node.inputs: + raise ValueError(f"Param {match} of node {node.label} should be defined and set") + def validate_params(self): """ This method will check if all required params are either set or linked @@ -132,6 +150,7 @@ def validate_params(self): :raises ValueError: if the pipeline is not valid """ for node in self.nodes: + self.special_prompt_validation(node) for param in node.inputs: if param.is_required and not self.is_param_set(node, param): raise ValueError(f"Param {param.code} of node {node.label} is required") diff --git a/aixplain/modules/pipeline/designer/utils.py b/aixplain/modules/pipeline/designer/utils.py new file mode 100644 index 00000000..250d5501 --- /dev/null +++ b/aixplain/modules/pipeline/designer/utils.py @@ -0,0 +1,13 @@ +import re +from typing import List + + +def find_prompt_params(prompt: str) -> List[str]: + """ + This method will find the prompt parameters in the prompt string. 
+ + :param prompt: the prompt string + :return: list of prompt parameters + """ + param_regex = re.compile(r"\{\{([^\}]+)\}\}") + return param_regex.findall(prompt) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..a03eea30 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +from dotenv import load_dotenv + +# Load environment variables once for all tests +load_dotenv() diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 5b247728..55d671e0 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -39,6 +39,7 @@ def read_data(data_path): def run_input_map(request): return request.param + @pytest.fixture(scope="function") def delete_agents_and_team_agents(): for team_agent in TeamAgentFactory.list()["results"]: @@ -100,12 +101,8 @@ def test_list_agents(): assert type(agents_result) is list -def test_update_draft_agent(run_input_map): - for team in TeamAgentFactory.list()["results"]: - team.delete() - - for agent in AgentFactory.list()["results"]: - agent.delete() +def test_update_draft_agent(run_input_map, delete_agents_and_team_agents): + assert delete_agents_and_team_agents tools = [] if "model_tools" in run_input_map: @@ -137,7 +134,8 @@ def test_update_draft_agent(run_input_map): agent.delete() -def test_fail_non_existent_llm(): +def test_fail_non_existent_llm(delete_agents_and_team_agents): + assert delete_agents_and_team_agents with pytest.raises(Exception) as exc_info: AgentFactory.create( name="Test Agent", @@ -147,6 +145,7 @@ def test_fail_non_existent_llm(): ) assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found." + def test_delete_agent_in_use(delete_agents_and_team_agents): assert delete_agents_and_team_agents agent = AgentFactory.create( @@ -160,7 +159,7 @@ def test_delete_agent_in_use(delete_agents_and_team_agents): description="Test description", use_mentalist_and_inspector=True, ) - + with pytest.raises(Exception) as exc_info: agent.delete() - assert str(exc_info.value) == "Agent Deletion Error (HTTP 403): err.agent_is_in_use." \ No newline at end of file + assert str(exc_info.value) == "Agent Deletion Error (HTTP 403): err.agent_is_in_use." 
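Reviewer note: the prompt-variable extraction that drives both `special_prompt_handling` and `special_prompt_validation` above is the single regex added in `aixplain/modules/pipeline/designer/utils.py`. A minimal standalone sketch of its behavior; the assertions mirror the cases exercised by `test_find_prompt_params` further down:

```python
import re
from typing import List


def find_prompt_params(prompt: str) -> List[str]:
    """Extract {{variable}} placeholders from a prompt string."""
    # [^\}]+ requires at least one non-brace character, so single braces
    # ({foo}) and empty placeholders ({{}}) never match.
    param_regex = re.compile(r"\{\{([^\}]+)\}\}")
    return param_regex.findall(prompt)


assert find_prompt_params("hello {{foo}} {{bar}}") == ["foo", "bar"]
assert find_prompt_params("hello {foo} bar") == []
assert find_prompt_params("{{}}") == []
# Unbalanced nesting is captured up to the first closing brace pair:
assert find_prompt_params("hello {{foo {{bar}} baz}}") == ["foo {{bar"]
```

Each placeholder found this way is registered as a required `TEXT` input param on the text-generation node, and `special_prompt_validation` then requires every one of them to be defined and set before the pipeline validates.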
diff --git a/tests/functional/pipelines/run_test.py b/tests/functional/pipelines/run_test.py index dbdb76fa..6ca9e6fe 100644 --- a/tests/functional/pipelines/run_test.py +++ b/tests/functional/pipelines/run_test.py @@ -241,3 +241,13 @@ def test_run_decision(input_data: str, output_data: str, version: str): assert response["status"] == "SUCCESS" assert response["data"][0]["label"] == output_data + + +@pytest.mark.parametrize("version", ["3.0"]) +def test_run_script(version: str): + pipeline = PipelineFactory.list(query="Script Functional Test - DO NOT DELETE")["results"][0] + response = pipeline.run("https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", **{"version": version}) + + assert response["status"] == "SUCCESS" + data = response["data"][0]["segments"][0]["response"] + assert data.startswith("SCRIPT MODIFIED:") diff --git a/tests/unit/designer_unit_test.py b/tests/unit/designer_unit_test.py index 824fd162..57276a20 100644 --- a/tests/unit/designer_unit_test.py +++ b/tests/unit/designer_unit_test.py @@ -1,6 +1,5 @@ import pytest -import unittest.mock as mock - +from unittest.mock import patch, Mock, call from aixplain.enums import DataType from aixplain.modules.pipeline.designer.base import ( @@ -21,7 +20,7 @@ from aixplain.modules.pipeline.designer.mixins import LinkableMixin from aixplain.modules.pipeline.designer.pipeline import DesignerPipeline - +from aixplain.modules.pipeline.designer.base import find_prompt_params def test_create_node(): @@ -30,7 +29,7 @@ def test_create_node(): class BareNode(Node): pass - with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: node = BareNode(number=3, label="FOO") mock_attach_to.assert_not_called() assert isinstance(node.inputs, Inputs) @@ -48,7 +47,7 @@ class FooNode(Node[FooNodeInputs, FooNodeOutputs]): inputs_class = FooNodeInputs outputs_class = FooNodeOutputs - with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: node = FooNode(pipeline=pipeline, number=3, label="FOO") mock_attach_to.assert_called_once_with(pipeline) assert isinstance(node.inputs, FooNodeInputs) @@ -115,8 +114,8 @@ class AssetNode(Node): node = AssetNode() - with mock.patch.object(node.inputs, "serialize") as mock_inputs_serialize: - with mock.patch.object(node.outputs, "serialize") as mock_outputs_serialize: + with patch.object(node.inputs, "serialize") as mock_inputs_serialize: + with patch.object(node.outputs, "serialize") as mock_outputs_serialize: assert node.serialize() == { "number": node.number, "type": NodeType.ASSET, @@ -145,7 +144,7 @@ def test_create_param(): class TypedParam(Param): param_type = ParamType.INPUT - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = TypedParam( code="param", data_type=DataType.TEXT, @@ -158,7 +157,7 @@ class TypedParam(Param): assert param.value == "foo" assert param.param_type == ParamType.INPUT - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = TypedParam( code="param", data_type=DataType.TEXT, @@ -175,7 +174,7 @@ class TypedParam(Param): class UnTypedParam(Param): pass - with 
mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -186,7 +185,7 @@ class UnTypedParam(Param): assert param.param_type == ParamType.OUTPUT - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -202,7 +201,7 @@ class AssetNode(Node): node = AssetNode() - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -226,7 +225,7 @@ class AssetNode(Node): node = AssetNode() - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = param_cls(code="param", data_type=DataType.TEXT, value="foo", node=node) mock_attach_to.assert_called_once_with(node) assert param.code == "param" @@ -253,7 +252,7 @@ class NoTypeParam(Param): input = InputParam(code="input", data_type=DataType.TEXT, value="foo") - with mock.patch.object(node.inputs, "add_param") as mock_add_param: + with patch.object(node.inputs, "add_param") as mock_add_param: input.attach_to(node) mock_add_param.assert_called_once_with(input) assert input.node is node @@ -265,7 +264,7 @@ class NoTypeParam(Param): output = OutputParam(code="output", data_type=DataType.TEXT, value="bar") - with mock.patch.object(node.outputs, "add_param") as mock_add_param: + with patch.object(node.outputs, "add_param") as mock_add_param: output.attach_to(node) mock_add_param.assert_called_once_with(output) assert output.node is node @@ -304,7 +303,7 @@ class AssetNode(Node, LinkableMixin): output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a) input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b) - with mock.patch.object(input, "back_link") as mock_back_link: + with patch.object(input, "back_link") as mock_back_link: output.link(input) mock_back_link.assert_called_once_with(output) @@ -342,7 +341,7 @@ class AssetNode(Node, LinkableMixin): output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a) input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b) - with mock.patch.object(a, "link") as mock_link: + with patch.object(a, "link") as mock_link: input.back_link(output) mock_link.assert_called_once_with(b, output, input) @@ -400,7 +399,7 @@ class AssetNode(Node, LinkableMixin): pipeline = DesignerPipeline() - with mock.patch("aixplain.modules.pipeline.designer.Link.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Link.attach_to") as mock_attach_to: link = Link( from_node=a, to_node=b, @@ -431,8 +430,8 @@ class AssetNode(Node, LinkableMixin): to_param="input", ) - with mock.patch.object(a, "attach_to") as mock_a_attach_to: - with mock.patch.object(b, "attach_to") as mock_b_attach_to: + with patch.object(a, "attach_to") as mock_a_attach_to: + with patch.object(b, "attach_to") as mock_b_attach_to: link.attach_to(pipeline) mock_a_attach_to.assert_called_once_with(pipeline) mock_b_attach_to.assert_called_once_with(pipeline) @@ -451,8 +450,8 @@ class AssetNode(Node, LinkableMixin): 
to_param="input", ) - with mock.patch.object(a, "attach_to") as mock_a_attach_to: - with mock.patch.object(b, "attach_to") as mock_b_attach_to: + with patch.object(a, "attach_to") as mock_a_attach_to: + with patch.object(b, "attach_to") as mock_b_attach_to: link.attach_to(pipeline) mock_a_attach_to.assert_not_called() mock_b_attach_to.assert_not_called() @@ -555,8 +554,8 @@ class AssetNode(Node): param_proxy = ParamProxy(node) - with mock.patch.object(param_proxy, "_create_param") as mock_create_param: - with mock.patch.object(param_proxy, "add_param") as mock_add_param: + with patch.object(param_proxy, "_create_param") as mock_create_param: + with patch.object(param_proxy, "add_param") as mock_add_param: param = param_proxy.create_param("foo", DataType.TEXT, "bar", is_required=True) mock_create_param.assert_called_once_with("foo", DataType.TEXT, "bar") mock_add_param.assert_called_once_with(param) @@ -588,6 +587,48 @@ class FooParam(Param): assert "'bar'" in str(excinfo.value) +def test_param_proxy_set_param_value(): + prompt_param = Mock(spec=Param, code="prompt") + param_proxy = ParamProxy(Mock()) + param_proxy._params = [prompt_param] + with patch.object(param_proxy, "special_prompt_handling") as mock_special_prompt_handling: + param_proxy.set_param_value("prompt", "hello {{foo}}") + mock_special_prompt_handling.assert_called_once_with("prompt", "hello {{foo}}") + assert prompt_param.value == "hello {{foo}}" + + +def test_param_proxy_special_prompt_handling(): + from aixplain.modules.pipeline.designer.nodes import AssetNode + + asset_node = Mock(spec=AssetNode, asset=Mock(function="text-generation")) + param_proxy = ParamProxy(asset_node) + with patch( + "aixplain.modules.pipeline.designer.base.find_prompt_params" + ) as mock_find_prompt_params: + mock_find_prompt_params.return_value = [] + param_proxy.special_prompt_handling("prompt", "hello {{foo}}") + mock_find_prompt_params.assert_called_once_with("hello {{foo}}") + asset_node.inputs.create_param.assert_not_called() + asset_node.reset_mock() + mock_find_prompt_params.reset_mock() + + mock_find_prompt_params.return_value = ["foo"] + param_proxy.special_prompt_handling("prompt", "hello {{foo}}") + mock_find_prompt_params.assert_called_once_with("hello {{foo}}") + asset_node.inputs.create_param.assert_called_once_with("foo", DataType.TEXT, is_required=True) + asset_node.reset_mock() + mock_find_prompt_params.reset_mock() + + mock_find_prompt_params.return_value = ["foo", "bar"] + param_proxy.special_prompt_handling("prompt", "hello {{foo}} {{bar}}") + mock_find_prompt_params.assert_called_once_with("hello {{foo}} {{bar}}") + assert asset_node.inputs.create_param.call_count == 2 + assert asset_node.inputs.create_param.call_args_list == [ + call("foo", DataType.TEXT, is_required=True), + call("bar", DataType.TEXT, is_required=True), + ] + + def test_node_link(): class AssetNode(Node, LinkableMixin): type: NodeType = NodeType.ASSET @@ -623,7 +664,7 @@ class AssetNode(Node): type: NodeType = NodeType.ASSET node1 = AssetNode() - with mock.patch.object(node1, "attach_to") as mock_attach_to: + with patch.object(node1, "attach_to") as mock_attach_to: pipeline.add_node(node1) mock_attach_to.assert_called_once_with(pipeline) @@ -636,14 +677,14 @@ class InputNode(Node): node = InputNode() - with mock.patch.object(pipeline, "add_node") as mock_add_node: + with patch.object(pipeline, "add_node") as mock_add_node: pipeline.add_nodes(node) assert mock_add_node.call_count == 1 node1 = InputNode() node2 = InputNode() - with mock.patch.object(pipeline, 
"add_node") as mock_add_node: + with patch.object(pipeline, "add_node") as mock_add_node: pipeline.add_nodes(node1, node2) assert mock_add_node.call_count == 2 @@ -662,6 +703,95 @@ class AssetNode(Node): link = Link(from_node=a, to_node=b, from_param="output", to_param="input") pipeline.add_link(link) - with mock.patch.object(link, "attach_to") as mock_attach_to: + with patch.object(link, "attach_to") as mock_attach_to: pipeline.add_link(link) mock_attach_to.assert_called_once_with(pipeline) + + +def test_pipeline_special_prompt_validation(): + from aixplain.modules.pipeline.designer.nodes import AssetNode + + pipeline = DesignerPipeline() + asset_node = Mock( + spec=AssetNode, + label="LLM(ID=1)", + asset=Mock(function="text-generation"), + inputs=Mock(prompt=Mock(value="hello {{foo}}"), text=Mock(is_required=True)), + ) + with patch.object(pipeline, "is_param_set") as mock_is_param_set: + mock_is_param_set.return_value = False + pipeline.special_prompt_validation(asset_node) + mock_is_param_set.assert_called_once_with(asset_node, asset_node.inputs.prompt) + assert asset_node.inputs.text.is_required is True + mock_is_param_set.reset_mock() + mock_is_param_set.return_value = True + with patch( + "aixplain.modules.pipeline.designer.pipeline.find_prompt_params" + ) as mock_find_prompt_params: + mock_find_prompt_params.return_value = [] + pipeline.special_prompt_validation(asset_node) + mock_is_param_set.assert_called_once_with( + asset_node, asset_node.inputs.prompt + ) + mock_find_prompt_params.assert_called_once_with( + asset_node.inputs.prompt.value + ) + assert asset_node.inputs.text.is_required is True + + mock_is_param_set.reset_mock() + mock_is_param_set.return_value = True + mock_find_prompt_params.reset_mock() + mock_find_prompt_params.return_value = ["foo"] + asset_node.inputs.__contains__ = Mock(return_value=False) + + with pytest.raises( + ValueError, + match="Param foo of node LLM\\(ID=1\\) should be defined and set", + ): + pipeline.special_prompt_validation(asset_node) + + mock_is_param_set.assert_called_once_with( + asset_node, asset_node.inputs.prompt + ) + mock_find_prompt_params.assert_called_once_with( + asset_node.inputs.prompt.value + ) + assert asset_node.inputs.text.is_required is False + + mock_is_param_set.reset_mock() + mock_is_param_set.return_value = True + mock_find_prompt_params.reset_mock() + mock_find_prompt_params.return_value = ["foo"] + asset_node.inputs.text.is_required = True + + asset_node.inputs.__contains__ = Mock(return_value=True) + pipeline.special_prompt_validation(asset_node) + mock_is_param_set.assert_called_once_with( + asset_node, asset_node.inputs.prompt + ) + mock_find_prompt_params.assert_called_once_with( + asset_node.inputs.prompt.value + ) + assert asset_node.inputs.text.is_required is False + + +@pytest.mark.parametrize( + "input, expected", + [ + ("hello {{foo}}", ["foo"]), + ("hello {{foo}} {{bar}}", ["foo", "bar"]), + ("hello {{foo}} {{bar}} {{baz}}", ["foo", "bar", "baz"]), + # no match cases + ("hello bar", []), + ("hello {{foo]] bar", []), + ("hello {foo} bar", []), + # edge cases + ("", []), + ("{{}}", []), + # interesting cases + ("hello {{foo {{bar}} baz}} {{bar}} {{baz}}", ["foo {{bar", "bar", "baz"]), + ], +) +def test_find_prompt_params(input, expected): + print(input, expected) + assert find_prompt_params(input) == expected diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index 1329e136..753a8f7a 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -4,7 +4,7 @@ load_dotenv() from aixplain.utils 
import config -from aixplain.enums import ModelStatus +from aixplain.enums import ResponseStatus from aixplain.modules.model.response import ModelResponse from aixplain.modules import LLM @@ -85,9 +85,76 @@ def test_run_sync(): response = test_model.run(data=input_data, temperature=0.001, max_tokens=128, top_p=1.0) assert isinstance(response, ModelResponse) - assert response.status == ModelStatus.SUCCESS + assert response.status == ResponseStatus.SUCCESS assert response.data == "Test Model Result" assert response.completed is True assert response.used_credits == 0 assert response.run_time == 0 assert response.usage is None + + +@pytest.mark.skip(reason="Need to fix model response") +def test_run_sync_polling_error(): + """Test handling of polling errors in the run method""" + model_id = "test-model-id" + base_url = config.MODELS_RUN_URL + execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute") + + ref_response = { + "status": "IN_PROGRESS", + "data": "https://models.aixplain.com/api/v1/data/invalid-id", + } + + with requests_mock.Mocker() as mock: + # Mock the initial execution call + mock.post(execute_url, json=ref_response) + + # Mock the polling URL to raise an exception + poll_url = ref_response["data"] + mock.get(poll_url, exc=Exception("Polling failed")) + + test_model = LLM(id=model_id, name="Test Model", function=Function.TEXT_GENERATION, url=base_url) + + response = test_model.run(data="test input") + + # Updated assertions to match ModelResponse structure + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.FAILED + assert response.completed is False + assert "No response from the service" in response.error_message + assert response.data == "" + assert response.used_credits == 0 + assert response.run_time == 0 + assert response.usage is None + + +def test_run_with_custom_parameters(): + """Test run method with custom parameters""" + model_id = "test-model-id" + base_url = config.MODELS_RUN_URL + execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute") + + ref_response = { + "completed": True, + "status": "SUCCESS", + "data": "Test Result", + "usedCredits": 10, + "runTime": 1.5, + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + } + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=ref_response) + + test_model = LLM(id=model_id, name="Test Model", function=Function.TEXT_GENERATION, url=base_url) + + custom_params = {"custom_param": "value", "temperature": 0.8} # This should override the default + + response = test_model.run(data="test input", temperature=0.5, parameters=custom_params) + + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.SUCCESS + assert response.data == "Test Result" + assert response.used_credits == 10 + assert response.run_time == 1.5 + assert response.usage == {"prompt_tokens": 10, "completion_tokens": 20} diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 94e2f6c2..b45b6ae0 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -16,10 +16,8 @@ limitations under the License. 
""" -from dotenv import load_dotenv import requests_mock -load_dotenv() import json from aixplain.utils import config from aixplain.modules import Model @@ -27,10 +25,11 @@ from aixplain.factories import ModelFactory from aixplain.enums import Function from urllib.parse import urljoin -from aixplain.enums import ModelStatus +from aixplain.enums import ResponseStatus from aixplain.modules.model.response import ModelResponse import pytest from unittest.mock import patch +from aixplain.enums.asset_status import AssetStatus def test_build_payload(): @@ -67,7 +66,7 @@ def test_call_run_endpoint_sync(): model_id = "model-id" execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute") payload = {"data": "input_data"} - ref_response = {"completed": True, "status": ModelStatus.SUCCESS, "data": "Hello"} + ref_response = {"completed": True, "status": ResponseStatus.SUCCESS, "data": "Hello"} with requests_mock.Mocker() as mock: mock.post(execute_url, json=ref_response) @@ -88,7 +87,7 @@ def test_success_poll(): hyp_response = test_model.poll(poll_url=poll_url) assert isinstance(hyp_response, ModelResponse) assert hyp_response["completed"] == ref_response["completed"] - assert hyp_response["status"] == ModelStatus.SUCCESS + assert hyp_response["status"] == ResponseStatus.SUCCESS def test_failed_poll(): @@ -103,7 +102,7 @@ def test_failed_poll(): response = model.poll(poll_url=poll_url) assert isinstance(response, ModelResponse) - assert response.status == ModelStatus.FAILED + assert response.status == ResponseStatus.FAILED assert response.error_message == "Some error occurred" assert response.completed is True @@ -145,7 +144,7 @@ def test_run_async_errors(status_code, error_message): test_model = Model(id=model_id, name="Test Model", url=base_url) response = test_model.run_async(data="input_data") assert isinstance(response, ModelResponse) - assert response["status"] == ModelStatus.FAILED + assert response["status"] == ResponseStatus.FAILED assert response["error_message"] == error_message @@ -219,7 +218,7 @@ def test_run_sync(): response = test_model.run(data=input_data, name="test_run") assert isinstance(response, ModelResponse) - assert response.status == ModelStatus.SUCCESS + assert response.status == ResponseStatus.SUCCESS assert response.data == "Test Model Result" assert response.completed is True assert response.used_credits == 0 @@ -256,3 +255,234 @@ def test_sync_poll(): assert response["completed"] is True assert response["details"] == {"test": "test"} assert response["data"] == "Polling successful result" + + +def test_run_with_parameters(): + model_id = "test-model-id" + base_url = config.MODELS_RUN_URL + execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute") + + input_data = "test input" + parameters = {"temperature": 0.7, "max_tokens": 100} + expected_payload = json.dumps({"data": input_data, **parameters}) + + ref_response = { + "completed": True, + "status": "SUCCESS", + "data": "Test Model Result", + "usedCredits": 0, + "runTime": 0, + } + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=ref_response) + + test_model = Model(id=model_id, name="Test Model", url=base_url, api_key=config.TEAM_API_KEY) + response = test_model.run(data=input_data, parameters=parameters) + + # Verify the payload was constructed correctly + assert mock.last_request.text == expected_payload + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.SUCCESS + assert response.data == "Test Model Result" + + 
+def test_run_async_with_parameters(): + model_id = "test-model-id" + base_url = config.MODELS_RUN_URL + execute_url = f"{base_url}/{model_id}" + + input_data = "test input" + parameters = {"temperature": 0.7, "max_tokens": 100} + expected_payload = json.dumps({"data": input_data, **parameters}) + + ref_response = { + "completed": False, + "status": "IN_PROGRESS", + "data": "https://models.aixplain.com/api/v1/data/test-id", + "url": "https://models.aixplain.com/api/v1/data/test-id", + } + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=ref_response) + + test_model = Model(id=model_id, name="Test Model", url=base_url, api_key=config.TEAM_API_KEY) + response = test_model.run_async(data=input_data, parameters=parameters) + + # Verify the payload was constructed correctly + assert mock.last_request.text == expected_payload + assert isinstance(response, ModelResponse) + assert response.status == "IN_PROGRESS" + assert response.url == ref_response["url"] + + +def test_successful_delete(): + with requests_mock.Mocker() as mock: + model_id = "test-model-id" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + headers = {"Authorization": "Token " + config.TEAM_API_KEY, "Content-Type": "application/json"} + + # Mock successful deletion + mock.delete(url, status_code=200) + + test_model = Model(id=model_id, name="Test Model") + test_model.delete() # Should not raise any exception + + # Verify the request was made with correct headers + assert mock.last_request.headers["Authorization"] == headers["Authorization"] + assert mock.last_request.headers["Content-Type"] == headers["Content-Type"] + + +def test_failed_delete(): + with requests_mock.Mocker() as mock: + model_id = "test-model-id" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + + # Mock failed deletion + mock.delete(url, status_code=404) + + test_model = Model(id=model_id, name="Test Model") + + with pytest.raises(Exception) as excinfo: + test_model.delete() + + assert "Model Deletion Error: Make sure the model exists and you are the owner." 
in str(excinfo.value) + + +def test_model_to_dict(): + # Test with regular additional info + model = Model(id="test-id", name="Test Model", description="", additional_info={"key1": "value1", "key2": None}) + result = model.to_dict() + + # Basic assertions + assert result["id"] == "test-id" + assert result["name"] == "Test Model" + assert result["description"] == "" + + # The additional_info is directly in the result + assert result["additional_info"] == {"additional_info": {"key1": "value1", "key2": None}} + + +def test_model_repr(): + # Test with supplier as dict + model1 = Model(id="test-id", name="Test Model", supplier={"name": "Test Supplier"}) + assert repr(model1) == "<Model: Test Model by Test Supplier>" + + # Test with supplier as string + model2 = Model(id="test-id", name="Test Model", supplier="Test Supplier") + assert str(model2) == "<Model: Test Model by Test Supplier>" + + +def test_poll_with_error(): + with requests_mock.Mocker() as mock: + poll_url = "https://models.aixplain.com/api/v1/data/test-id" + headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + + # Mock a response that will cause a JSON decode error + mock.get(poll_url, headers=headers, text="Invalid JSON") + + model = Model(id="test-id", name="Test Model") + response = model.poll(poll_url=poll_url) + + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.FAILED + assert "Expecting value: line 1 column 1" in response.error_message + + +def test_sync_poll_with_timeout(): + poll_url = "https://models.aixplain.com/api/v1/data/test-id" + model = Model(id="test-id", name="Test Model") + + # Mock poll method to always return not completed + with patch.object(model, "poll") as mock_poll: + mock_poll.return_value = {"status": "IN_PROGRESS", "completed": False, "error_message": ""} + + # Test with very short timeout + response = model.sync_poll(poll_url=poll_url, timeout=0.1, wait_time=0.2) + + assert response["status"] == "FAILED" + assert response["completed"] is False + + +def test_check_finetune_status_error(): + with requests_mock.Mocker() as mock: + model_id = "test-id" + url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs") + headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + + # Mock error response + error_response = {"statusCode": 404, "message": "Finetune not found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + model = Model(id=model_id, name="Test Model") + status = model.check_finetune_status() + + assert status is None + + +def test_check_finetune_status_with_logs(): + with requests_mock.Mocker() as mock: + model_id = "test-id" + url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs") + + # Mock successful response with logs using valid AssetStatus values + success_response = { + "finetuneStatus": AssetStatus.COMPLETED.value, + "modelStatus": AssetStatus.COMPLETED.value, + "logs": [{"epoch": 1.0, "trainLoss": 0.5, "evalLoss": 0.4}, {"epoch": 2.0, "trainLoss": 0.3, "evalLoss": 0.2}], + } + mock.get(url, json=success_response) + + model = Model(id=model_id, name="Test Model", description="") + + # Test with after_epoch + status = model.check_finetune_status(after_epoch=0) + assert status is not None + assert status.epoch == 1.0 + assert status.training_loss == 0.5 + assert status.validation_loss == 0.4 + + # Test without after_epoch + status = model.check_finetune_status() + assert status is not None + assert status.epoch == 2.0 + assert status.training_loss == 0.3 + assert status.validation_loss == 0.2 + + +def
test_check_finetune_status_partial_logs(): + with requests_mock.Mocker() as mock: + model_id = "test-id" + url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs") + + response = { + "finetuneStatus": AssetStatus.IN_PROGRESS.value, + "modelStatus": AssetStatus.IN_PROGRESS.value, + "logs": [{"epoch": 1.0, "trainLoss": 0.5, "evalLoss": 0.4}, {"epoch": 2.0, "trainLoss": 0.3, "evalLoss": 0.2}], + } + mock.get(url, json=response) + + model = Model(id=model_id, name="Test Model", description="") + status = model.check_finetune_status() + + assert status is not None + assert status.epoch == 2.0 + assert status.training_loss == 0.3 + assert status.validation_loss == 0.2 + + +def test_check_finetune_status_no_logs(): + with requests_mock.Mocker() as mock: + model_id = "test-id" + url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs") + + response = {"finetuneStatus": AssetStatus.IN_PROGRESS.value, "modelStatus": AssetStatus.IN_PROGRESS.value, "logs": []} + mock.get(url, json=response) + + model = Model(id=model_id, name="Test Model", description="") + status = model.check_finetune_status() + + assert status is not None + assert status.epoch is None + assert status.training_loss is None + assert status.validation_loss is None
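Taken together, these changes mean downstream code should compare against `ResponseStatus` rather than the removed `ModelStatus`, and may rely on `parameters` defaulting to `None` instead of a shared mutable `{}`. A hedged usage sketch (the model ID is a placeholder, and a valid `TEAM_API_KEY` is assumed to be configured):

```python
from aixplain.enums import ResponseStatus
from aixplain.factories import ModelFactory

# "MODEL_ID" is a placeholder; substitute any model your team can access.
model = ModelFactory.get("MODEL_ID")

# parameters defaults to None now, so omitting it and passing a fresh dict
# behave identically across repeated calls.
response = model.run(data="hello world", parameters={"temperature": 0.7})

# ResponseStatus subclasses Text/str, so enum and raw-string comparisons both
# work, which keeps pre-existing '== "SUCCESS"' checks passing.
if response.status == ResponseStatus.SUCCESS:
    print(response.data)
else:
    print(f"Run failed: {response.error_message}")
```

Because `ModelResponse` also supports item access (the tests above index `response["status"]` directly), dict-style callers keep working through the migration.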