From 27bd338e9153bff29d1da18d95d6cf5cae7036a8 Mon Sep 17 00:00:00 2001
From: Lucas Pavanelli
Date: Tue, 5 Nov 2024 10:45:36 -0300
Subject: [PATCH] Set default 'parameters' to None and add tests

---
 .pre-commit-config.yaml             |   2 +-
 aixplain/modules/model/__init__.py  |  11 +-
 aixplain/modules/model/llm_model.py |  36 +++--
 aixplain/modules/model/utils.py     |   7 +-
 tests/conftest.py                   |   4 +
 tests/unit/llm_test.py              |  67 ++++++++
 tests/unit/model_test.py            | 234 +++++++++++++++++++++++++++-
 7 files changed, 336 insertions(+), 25 deletions(-)
 create mode 100644 tests/conftest.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a79973ee..c6b06079 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ repos:
     hooks:
       - id: pytest-check
         name: pytest-check
-        entry: coverage run -m pytest tests/unit
+        entry: coverage run --source=. -m pytest tests/unit
         language: python
         pass_filenames: false
         types: [python]
diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py
index 41abf865..c1897095 100644
--- a/aixplain/modules/model/__init__.py
+++ b/aixplain/modules/model/__init__.py
@@ -106,6 +106,7 @@ def to_dict(self) -> Dict:
         return {
             "id": self.id,
             "name": self.name,
+            "description": self.description,
             "supplier": self.supplier,
             "additional_info": clean_additional_info,
             "input_params": self.input_params,
@@ -205,7 +206,7 @@ def run(
         data: Union[Text, Dict],
         name: Text = "model_process",
         timeout: float = 300,
-        parameters: Optional[Dict] = {},
+        parameters: Optional[Dict] = None,
         wait_time: float = 0.5,
     ) -> ModelResponse:
         """Runs a model call.
@@ -214,7 +215,7 @@
             data (Union[Text, Dict]): link to the input data
             name (Text, optional): ID given to a call. Defaults to "model_process".
             timeout (float, optional): total polling time. Defaults to 300.
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.
             wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

         Returns:
@@ -247,13 +248,15 @@ def run(
             **response,
         )

-    def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> ModelResponse:
+    def run_async(
+        self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None
+    ) -> ModelResponse:
         """Runs asynchronously a model call.

         Args:
             data (Union[Text, Dict]): link to the input data
             name (Text, optional): ID given to a call. Defaults to "model_process".
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.

         Returns:
             dict: polling URL in response
diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py
index 941c4a6f..846524f4 100644
--- a/aixplain/modules/model/llm_model.py
+++ b/aixplain/modules/model/llm_model.py
@@ -104,7 +104,7 @@ def run(
         top_p: float = 1.0,
         name: Text = "model_process",
         timeout: float = 300,
-        parameters: Optional[Dict] = {},
+        parameters: Optional[Dict] = None,
         wait_time: float = 0.5,
     ) -> ModelResponse:
         """Synchronously running a Large Language Model (LLM) model.
@@ -119,21 +119,23 @@
             top_p (float, optional): Top P. Defaults to 1.0.
             name (Text, optional): ID given to a call. Defaults to "model_process".
             timeout (float, optional): total polling time. Defaults to 300.
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.
             wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

         Returns:
             Dict: parsed output from model
         """
         start = time.time()
+        if parameters is None:
+            parameters = {}
         parameters.update(
             {
-                "context": parameters["context"] if "context" in parameters else context,
-                "prompt": parameters["prompt"] if "prompt" in parameters else prompt,
-                "history": parameters["history"] if "history" in parameters else history,
-                "temperature": parameters["temperature"] if "temperature" in parameters else temperature,
-                "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
-                "top_p": parameters["top_p"] if "top_p" in parameters else top_p,
+                "context": parameters.get("context", context),
+                "prompt": parameters.get("prompt", prompt),
+                "history": parameters.get("history", history),
+                "temperature": parameters.get("temperature", temperature),
+                "max_tokens": parameters.get("max_tokens", max_tokens),
+                "top_p": parameters.get("top_p", top_p),
             }
         )
         payload = build_payload(data=data, parameters=parameters)
@@ -173,7 +175,7 @@ def run_async(
         max_tokens: int = 128,
         top_p: float = 1.0,
         name: Text = "model_process",
-        parameters: Optional[Dict] = {},
+        parameters: Optional[Dict] = None,
     ) -> ModelResponse:
         """Runs asynchronously a model call.

@@ -186,21 +188,23 @@
             max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
             top_p (float, optional): Top P. Defaults to 1.0.
             name (Text, optional): ID given to a call. Defaults to "model_process".
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.

         Returns:
             dict: polling URL in response
         """
         url = f"{self.url}/{self.id}"
         logging.debug(f"Model Run Async: Start service for {name} - {url}")
+        if parameters is None:
+            parameters = {}
         parameters.update(
             {
-                "context": parameters["context"] if "context" in parameters else context,
-                "prompt": parameters["prompt"] if "prompt" in parameters else prompt,
-                "history": parameters["history"] if "history" in parameters else history,
-                "temperature": parameters["temperature"] if "temperature" in parameters else temperature,
-                "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
-                "top_p": parameters["top_p"] if "top_p" in parameters else top_p,
+                "context": parameters.get("context", context),
+                "prompt": parameters.get("prompt", prompt),
+                "history": parameters.get("history", history),
+                "temperature": parameters.get("temperature", temperature),
+                "max_tokens": parameters.get("max_tokens", max_tokens),
+                "top_p": parameters.get("top_p", top_p),
             }
         )
         payload = build_payload(data=data, parameters=parameters)
diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py
index 2235b35a..13cc1f7c 100644
--- a/aixplain/modules/model/utils.py
+++ b/aixplain/modules/model/utils.py
@@ -3,12 +3,15 @@
 import json
 import logging
 from aixplain.utils.file_utils import _request_with_retry
-from typing import Dict, Text, Union
+from typing import Dict, Text, Union, Optional


-def build_payload(data: Union[Text, Dict], parameters: Dict = {}):
+def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None):
     from aixplain.factories import FileFactory

+    if parameters is None:
+        parameters = {}
+
     data = FileFactory.to_link(data)
     if isinstance(data, dict):
         payload = data
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..a03eea30
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,4 @@
+from dotenv import load_dotenv
+
+# Load environment variables once for all tests
+load_dotenv()
diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py
index 1329e136..5db8d9de 100644
--- a/tests/unit/llm_test.py
+++ b/tests/unit/llm_test.py
@@ -91,3 +91,70 @@ def test_run_sync():
     assert response.used_credits == 0
     assert response.run_time == 0
     assert response.usage is None
+
+
+@pytest.mark.skip(reason="Need to fix model response")
+def test_run_sync_polling_error():
+    """Test handling of polling errors in the run method"""
+    model_id = "test-model-id"
+    base_url = config.MODELS_RUN_URL
+    execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute")
+
+    ref_response = {
+        "status": "IN_PROGRESS",
+        "data": "https://models.aixplain.com/api/v1/data/invalid-id",
+    }
+
+    with requests_mock.Mocker() as mock:
+        # Mock the initial execution call
+        mock.post(execute_url, json=ref_response)
+
+        # Mock the polling URL to raise an exception
+        poll_url = ref_response["data"]
+        mock.get(poll_url, exc=Exception("Polling failed"))
+
+        test_model = LLM(id=model_id, name="Test Model", function=Function.TEXT_GENERATION, url=base_url)
+
+        response = test_model.run(data="test input")
+
+        # Updated assertions to match ModelResponse structure
+        assert isinstance(response, ModelResponse)
+        assert response.status == ModelStatus.FAILED
+        assert response.completed is False
+        assert "No response from the service" in response.error_message
+        assert response.data == ""
+        assert response.used_credits == 0
+        assert response.run_time == 0
+        assert response.usage is None
+
+
+def test_run_with_custom_parameters():
+    """Test run method with custom parameters"""
+    model_id = "test-model-id"
+    base_url = config.MODELS_RUN_URL
+    execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute")
+
+    ref_response = {
+        "completed": True,
+        "status": "SUCCESS",
+        "data": "Test Result",
+        "usedCredits": 10,
+        "runTime": 1.5,
+        "usage": {"prompt_tokens": 10, "completion_tokens": 20},
+    }
+
+    with requests_mock.Mocker() as mock:
+        mock.post(execute_url, json=ref_response)
+
+        test_model = LLM(id=model_id, name="Test Model", function=Function.TEXT_GENERATION, url=base_url)
+
+        custom_params = {"custom_param": "value", "temperature": 0.8}  # This should override the default
+
+        response = test_model.run(data="test input", temperature=0.5, parameters=custom_params)
+
+        assert isinstance(response, ModelResponse)
+        assert response.status == ModelStatus.SUCCESS
+        assert response.data == "Test Result"
+        assert response.used_credits == 10
+        assert response.run_time == 1.5
+        assert response.usage == {"prompt_tokens": 10, "completion_tokens": 20}
diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py
index 94e2f6c2..33f436bb 100644
--- a/tests/unit/model_test.py
+++ b/tests/unit/model_test.py
@@ -16,10 +16,8 @@
 limitations under the License.
""" -from dotenv import load_dotenv import requests_mock -load_dotenv() import json from aixplain.utils import config from aixplain.modules import Model @@ -31,6 +29,7 @@ from aixplain.modules.model.response import ModelResponse import pytest from unittest.mock import patch +from aixplain.enums.asset_status import AssetStatus def test_build_payload(): @@ -256,3 +255,234 @@ def test_sync_poll(): assert response["completed"] is True assert response["details"] == {"test": "test"} assert response["data"] == "Polling successful result" + + +def test_run_with_parameters(): + model_id = "test-model-id" + base_url = config.MODELS_RUN_URL + execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute") + + input_data = "test input" + parameters = {"temperature": 0.7, "max_tokens": 100} + expected_payload = json.dumps({"data": input_data, **parameters}) + + ref_response = { + "completed": True, + "status": "SUCCESS", + "data": "Test Model Result", + "usedCredits": 0, + "runTime": 0, + } + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=ref_response) + + test_model = Model(id=model_id, name="Test Model", url=base_url, api_key=config.TEAM_API_KEY) + response = test_model.run(data=input_data, parameters=parameters) + + # Verify the payload was constructed correctly + assert mock.last_request.text == expected_payload + assert isinstance(response, ModelResponse) + assert response.status == ModelStatus.SUCCESS + assert response.data == "Test Model Result" + + +def test_run_async_with_parameters(): + model_id = "test-model-id" + base_url = config.MODELS_RUN_URL + execute_url = f"{base_url}/{model_id}" + + input_data = "test input" + parameters = {"temperature": 0.7, "max_tokens": 100} + expected_payload = json.dumps({"data": input_data, **parameters}) + + ref_response = { + "completed": False, + "status": "IN_PROGRESS", + "data": "https://models.aixplain.com/api/v1/data/test-id", + "url": "https://models.aixplain.com/api/v1/data/test-id", + } + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=ref_response) + + test_model = Model(id=model_id, name="Test Model", url=base_url, api_key=config.TEAM_API_KEY) + response = test_model.run_async(data=input_data, parameters=parameters) + + # Verify the payload was constructed correctly + assert mock.last_request.text == expected_payload + assert isinstance(response, ModelResponse) + assert response.status == "IN_PROGRESS" + assert response.url == ref_response["url"] + + +def test_successful_delete(): + with requests_mock.Mocker() as mock: + model_id = "test-model-id" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + headers = {"Authorization": "Token " + config.TEAM_API_KEY, "Content-Type": "application/json"} + + # Mock successful deletion + mock.delete(url, status_code=200) + + test_model = Model(id=model_id, name="Test Model") + test_model.delete() # Should not raise any exception + + # Verify the request was made with correct headers + assert mock.last_request.headers["Authorization"] == headers["Authorization"] + assert mock.last_request.headers["Content-Type"] == headers["Content-Type"] + + +def test_failed_delete(): + with requests_mock.Mocker() as mock: + model_id = "test-model-id" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + + # Mock failed deletion + mock.delete(url, status_code=404) + + test_model = Model(id=model_id, name="Test Model") + + with pytest.raises(Exception) as excinfo: + test_model.delete() + + assert "Model Deletion Error: Make sure the model 
+
+
+def test_model_to_dict():
+    # Test with regular additional info
+    model = Model(id="test-id", name="Test Model", description="", additional_info={"key1": "value1", "key2": None})
+    result = model.to_dict()
+
+    # Basic assertions
+    assert result["id"] == "test-id"
+    assert result["name"] == "Test Model"
+    assert result["description"] == ""
+
+    # The additional_info is directly in the result
+    assert result["additional_info"] == {"additional_info": {"key1": "value1", "key2": None}}
+
+
+def test_model_repr():
+    # Test with supplier as dict
+    model1 = Model(id="test-id", name="Test Model", supplier={"name": "Test Supplier"})
+    assert repr(model1) == "<Model: Test Model by Test Supplier>"
+
+    # Test with supplier as string
+    model2 = Model(id="test-id", name="Test Model", supplier="Test Supplier")
+    assert str(model2) == "<Model: Test Model by Test Supplier>"
+
+
+def test_poll_with_error():
+    with requests_mock.Mocker() as mock:
+        poll_url = "https://models.aixplain.com/api/v1/data/test-id"
+        headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}
+
+        # Mock a response that will cause a JSON decode error
+        mock.get(poll_url, headers=headers, text="Invalid JSON")
+
+        model = Model(id="test-id", name="Test Model")
+        response = model.poll(poll_url=poll_url)
+
+        assert isinstance(response, ModelResponse)
+        assert response.status == ModelStatus.FAILED
+        assert "Expecting value: line 1 column 1" in response.error_message
+
+
+def test_sync_poll_with_timeout():
+    poll_url = "https://models.aixplain.com/api/v1/data/test-id"
+    model = Model(id="test-id", name="Test Model")
+
+    # Mock poll method to always return not completed
+    with patch.object(model, "poll") as mock_poll:
+        mock_poll.return_value = {"status": "IN_PROGRESS", "completed": False, "error_message": ""}
+
+        # Test with very short timeout
+        response = model.sync_poll(poll_url=poll_url, timeout=0.1, wait_time=0.2)
+
+        assert response["status"] == "FAILED"
+        assert response["completed"] is False
+
+
+def test_check_finetune_status_error():
+    with requests_mock.Mocker() as mock:
+        model_id = "test-id"
+        url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs")
+        headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}
+
+        # Mock error response
+        error_response = {"statusCode": 404, "message": "Finetune not found"}
+        mock.get(url, headers=headers, json=error_response, status_code=404)
+
+        model = Model(id=model_id, name="Test Model")
+        status = model.check_finetune_status()
+
+        assert status is None
+
+
+def test_check_finetune_status_with_logs():
+    with requests_mock.Mocker() as mock:
+        model_id = "test-id"
+        url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs")
+
+        # Mock successful response with logs using valid ModelStatus values
+        success_response = {
+            "finetuneStatus": AssetStatus.COMPLETED.value,
+            "modelStatus": AssetStatus.COMPLETED.value,
+            "logs": [{"epoch": 1.0, "trainLoss": 0.5, "evalLoss": 0.4}, {"epoch": 2.0, "trainLoss": 0.3, "evalLoss": 0.2}],
+        }
+        mock.get(url, json=success_response)
+
+        model = Model(id=model_id, name="Test Model", description="")
+
+        # Test with after_epoch
+        status = model.check_finetune_status(after_epoch=0)
+        assert status is not None
+        assert status.epoch == 1.0
+        assert status.training_loss == 0.5
+        assert status.validation_loss == 0.4
+
+        # Test without after_epoch
+        status = model.check_finetune_status()
+        assert status is not None
+        assert status.epoch == 2.0
+        assert status.training_loss == 0.3
+        assert status.validation_loss == 0.2
+
+
+def test_check_finetune_status_partial_logs():
+    with requests_mock.Mocker() as mock:
+        model_id = "test-id"
+        url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs")
+
+        response = {
+            "finetuneStatus": AssetStatus.IN_PROGRESS.value,
+            "modelStatus": AssetStatus.IN_PROGRESS.value,
+            "logs": [{"epoch": 1.0, "trainLoss": 0.5, "evalLoss": 0.4}, {"epoch": 2.0, "trainLoss": 0.3, "evalLoss": 0.2}],
+        }
+        mock.get(url, json=response)
+
+        model = Model(id=model_id, name="Test Model", description="")
+        status = model.check_finetune_status()
+
+        assert status is not None
+        assert status.epoch == 2.0
+        assert status.training_loss == 0.3
+        assert status.validation_loss == 0.2
+
+
+def test_check_finetune_status_no_logs():
+    with requests_mock.Mocker() as mock:
+        model_id = "test-id"
+        url = urljoin(config.BACKEND_URL, f"sdk/finetune/{model_id}/ml-logs")
+
+        response = {"finetuneStatus": AssetStatus.IN_PROGRESS.value, "modelStatus": AssetStatus.IN_PROGRESS.value, "logs": []}
+        mock.get(url, json=response)
+
+        model = Model(id=model_id, name="Test Model", description="")
+        status = model.check_finetune_status()
+
+        assert status is not None
+        assert status.epoch is None
+        assert status.training_loss is None
+        assert status.validation_loss is None
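
Background on the core change: in Python, default argument values are evaluated once, at function definition time, so a mutable default such as `parameters: Dict = {}` is a single dict shared by every call that omits the argument, and any mutation (like the `parameters.update(...)` calls above) leaks into later calls. The patch replaces it with the standard None-sentinel pattern. A minimal standalone sketch of the pitfall (not part of the patch; `run_bad` and `run_good` are illustrative names only):

# Standalone sketch of the mutable-default pitfall this patch fixes.
# Not part of the patch; function names here are illustrative only.


def run_bad(data, parameters={}):  # one shared dict, created at def-time
    parameters.setdefault("calls", []).append(data)
    return parameters


def run_good(data, parameters=None):  # None-sentinel pattern, as in the patch
    if parameters is None:
        parameters = {}  # fresh dict on every call
    parameters.setdefault("calls", []).append(data)
    return parameters


print(run_bad("a"))   # {'calls': ['a']}
print(run_bad("b"))   # {'calls': ['a', 'b']}  <- state leaked from the first call
print(run_good("a"))  # {'calls': ['a']}
print(run_good("b"))  # {'calls': ['b']}       <- no leak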