diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py
index c11d837a..d82bdd63 100644
--- a/aixplain/factories/model_factory.py
+++ b/aixplain/factories/model_factory.py
@@ -30,6 +30,8 @@ from aixplain.utils.file_utils import _request_with_retry
 from urllib.parse import urljoin
 from warnings import warn
+from aixplain.enums.function import FunctionInputOutput
+from datetime import datetime
 
 
 class ModelFactory:
@@ -66,6 +68,15 @@ def _create_model_from_response(cls, response: Dict) -> Model:
         if function == Function.TEXT_GENERATION:
             ModelClass = LLM
 
+        created_at = None
+        if "createdAt" in response and response["createdAt"]:
+            created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00"))
+        function_id = response["function"]["id"]
+        function = Function(function_id)
+        function_io = FunctionInputOutput.get(function_id, None)
+        input_params = {param["code"]: param for param in function_io["spec"]["params"]}
+        output_params = {param["code"]: param for param in function_io["spec"]["output"]}
+
         return ModelClass(
             response["id"],
             response["name"],
@@ -73,7 +84,10 @@ def _create_model_from_response(cls, response: Dict) -> Model:
             api_key=response["api_key"],
             cost=response["pricing"],
             function=function,
+            created_at=created_at,
             parameters=parameters,
+            input_params=input_params,
+            output_params=output_params,
             is_subscribed=True if "subscription" in response else False,
             version=response["version"]["id"],
         )
@@ -270,7 +284,7 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]:
         for dictionary in response_dicts:
             del dictionary["id"]
         return response_dicts
-    
+
     @classmethod
     def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]:
         """List GPU names on which you can host your language model.
@@ -335,7 +349,7 @@ def create_asset_repo(
         input_modality: Text,
         output_modality: Text,
         documentation_url: Optional[Text] = "",
-        api_key: Optional[Text] = None
+        api_key: Optional[Text] = None,
     ) -> Dict:
         """Creates an image repository for this model and registers it in the platform backend.
@@ -362,7 +376,7 @@ def create_asset_repo(
         function_id = function_dict["id"]
         if function_id is None:
             raise Exception(f"Invalid function name {function}")
-        create_url = urljoin(config.BACKEND_URL, f"sdk/models/onboard")
+        create_url = urljoin(config.BACKEND_URL, "sdk/models/onboard")
         logging.debug(f"URL: {create_url}")
         if api_key:
             headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"}
@@ -373,19 +387,14 @@ def create_asset_repo(
             "model": {
                 "name": name,
                 "description": description,
-                "connectionType": [
-                    "synchronous"
-                ],
+                "connectionType": ["synchronous"],
                 "function": function_id,
-                "modalities": [
-                    f"{input_modality}-{output_modality}"
-                ],
+                "modalities": [f"{input_modality}-{output_modality}"],
                 "documentationUrl": documentation_url,
-                "sourceLanguage": source_language
+                "sourceLanguage": source_language,
             },
             "source": "aixplain-ecr",
-            "onboardingParams": {
-            }
+            "onboardingParams": {},
         }
         logging.debug(f"Body: {str(payload)}")
         response = _request_with_retry("post", create_url, headers=headers, json=payload)
@@ -412,12 +421,18 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict:
         else:
             headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
         response = _request_with_retry("post", login_url, headers=headers)
-        print(f"Response: {response}")
         response_dict = json.loads(response.text)
         return response_dict
 
     @classmethod
-    def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_machine: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
+    def onboard_model(
+        cls,
+        model_id: Text,
+        image_tag: Text,
+        image_hash: Text,
+        host_machine: Optional[Text] = "",
+        api_key: Optional[Text] = None,
+    ) -> Dict:
         """Onboard a model after its image has been pushed to ECR.
 
         Args:
@@ -446,7 +461,14 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_m
         return response
 
     @classmethod
-    def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Optional[Text] = "", hf_token: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
+    def deploy_huggingface_model(
+        cls,
+        name: Text,
+        hf_repo_id: Text,
+        revision: Optional[Text] = "",
+        hf_token: Optional[Text] = "",
+        api_key: Optional[Text] = None,
+    ) -> Dict:
         """Onboards and deploys a Hugging Face large language model.
 
         Args:
@@ -477,8 +499,8 @@ def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Option
                 "hf_supplier": supplier,
                 "hf_model_name": model_name,
                 "hf_token": hf_token,
-                "revision": revision
-            }
+                "revision": revision,
+            },
         }
         response = _request_with_retry("post", deploy_url, headers=headers, json=body)
         logging.debug(response.text)
diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py
index 8fcd80d2..2e9445b5 100644
--- a/aixplain/modules/model/__init__.py
+++ b/aixplain/modules/model/__init__.py
@@ -30,6 +30,7 @@ from urllib.parse import urljoin
 from aixplain.utils.file_utils import _request_with_retry
 from typing import Union, Optional, Text, Dict
+from datetime import datetime
 
 
 class Model(Asset):
@@ -48,6 +49,8 @@ class Model(Asset):
         backend_url (str): URL of the backend.
         pricing (Dict, optional): model price. Defaults to None.
         **additional_info: Any additional Model info to be saved
+        input_params (Dict, optional): input parameters for the function.
+        output_params (Dict, optional): output parameters for the function.
""" def __init__( @@ -61,6 +64,9 @@ def __init__( function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, + created_at: Optional[datetime] = None, + input_params: Optional[Dict] = None, + output_params: Optional[Dict] = None, **additional_info, ) -> None: """Model Init @@ -84,6 +90,9 @@ def __init__( self.backend_url = config.BACKEND_URL self.function = function self.is_subscribed = is_subscribed + self.created_at = created_at + self.input_params = input_params + self.output_params = output_params def to_dict(self) -> Dict: """Get the model info as a Dictionary @@ -92,7 +101,14 @@ def to_dict(self) -> Dict: Dict: Model Information """ clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None} - return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info} + return { + "id": self.id, + "name": self.name, + "supplier": self.supplier, + "additional_info": clean_additional_info, + "input_params": self.input_params, + "output_params": self.output_params, + } def __repr__(self): try: @@ -257,7 +273,9 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." else: status_code = str(r.status_code) - error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + error = ( + f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + ) response = {"status": "FAILED", "error_message": error} logging.error(f"Error in request for {name} - {r.status_code}: {error}") except Exception: diff --git a/pyproject.toml b/pyproject.toml index 5b0ded4b..be397bdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ namespaces = true [project] name = "aiXplain" -version = "0.2.13rc2" +version = "0.2.18" description = "aiXplain SDK adds AI functions to software." 
readme = "README.md" requires-python = ">=3.5, <4" @@ -49,7 +49,7 @@ dependencies = [ "python-dotenv>=1.0.0", "validators>=0.20.0", "filetype>=1.2.0", - "click>=7.1.2,<8.0.0", + "click>=7.1.2", "PyYAML>=6.0.1", "dataclasses-json>=0.5.2", "Jinja2==3.1.4", diff --git a/tests/functional/finetune/finetune_functional_test.py b/tests/functional/finetune/finetune_functional_test.py index 7b45613c..46520137 100644 --- a/tests/functional/finetune/finetune_functional_test.py +++ b/tests/functional/finetune/finetune_functional_test.py @@ -1,5 +1,4 @@ __author__ = "lucaspavanelli" - """ Copyright 2022 The aiXplain SDK authors @@ -26,6 +25,7 @@ from aixplain.factories import FinetuneFactory from aixplain.modules.finetune.cost import FinetuneCost from aixplain.enums import Function, Language +from datetime import datetime, timedelta, timezone import pytest @@ -40,11 +40,6 @@ def read_data(data_path): return json.load(open(data_path, "r")) -@pytest.fixture(scope="module", params=read_data(RUN_FILE)) -def run_input_map(request): - return request.param - - @pytest.fixture(scope="module", params=read_data(ESTIMATE_COST_FILE)) def estimate_cost_input_map(request): return request.param @@ -60,11 +55,32 @@ def validate_prompt_input_map(request): return request.param -def test_end2end(run_input_map): - model = ModelFactory.get(run_input_map["model_id"]) - dataset_list = [DatasetFactory.list(query=run_input_map["dataset_name"])["results"][0]] +def pytest_generate_tests(metafunc): + if "input_map" in metafunc.fixturenames: + four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4) + models = ModelFactory.list(function=Function.TEXT_GENERATION, is_finetunable=True)["results"] + + recent_models = [ + { + "model_name": model.name, + "model_id": model.id, + "dataset_name": "Test text generation dataset", + "inference_data": "Hello!", + "required_dev": True, + "search_metadata": False, + } + for model in models + if model.created_at is not None and model.created_at >= four_weeks_ago + ] + recent_models += read_data(RUN_FILE) + metafunc.parametrize("input_map", recent_models) + + +def test_end2end(input_map): + model = input_map["model_id"] + dataset_list = [DatasetFactory.list(query=input_map["dataset_name"])["results"][0]] train_percentage, dev_percentage = 100, 0 - if run_input_map["required_dev"]: + if input_map["required_dev"]: train_percentage, dev_percentage = 80, 20 finetune = FinetuneFactory.create( str(uuid.uuid4()), dataset_list, model, train_percentage=train_percentage, dev_percentage=dev_percentage @@ -85,12 +101,12 @@ def test_end2end(run_input_map): assert finetune_model.check_finetune_status().model_status.value == "onboarded" time.sleep(30) print(f"Model dict: {finetune_model.__dict__}") - result = finetune_model.run(run_input_map["inference_data"]) + result = finetune_model.run(input_map["inference_data"]) print(f"Result: {result}") assert result is not None - if run_input_map["search_metadata"]: + if input_map["search_metadata"]: assert "details" in result - assert len(result["details"]) > 0 + assert len(result["details"]) > 0 assert "metadata" in result["details"][0] assert len(result["details"][0]["metadata"]) > 0 finetune_model.delete() diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py index d35a4d9a..b0d8f6ef 100644 --- a/tests/functional/general_assets/asset_functional_test.py +++ b/tests/functional/general_assets/asset_functional_test.py @@ -112,3 +112,25 @@ def test_llm_instantiation(): """Test that the LLM 
model is correctly instantiated.""" models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"] assert isinstance(models[0], LLM) + + +def test_model_io(): + model_id = "64aee5824d34b1221e70ac07" + model = ModelFactory.get(model_id) + + expected_input = { + "text": { + "name": "Text Prompt", + "code": "text", + "required": True, + "isFixed": False, + "dataType": "text", + "dataSubType": "text", + "multipleValues": False, + "defaultValues": [], + } + } + expected_output = {"data": {"name": "Generated Image", "code": "data", "defaultValue": [], "dataType": "image"}} + + assert model.input_params == expected_input + assert model.output_params == expected_output diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index cd6f7a5a..c52bb950 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -63,24 +63,22 @@ def test_failed_poll(): @pytest.mark.parametrize( "status_code,error_message", [ - (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."), - (475,"Billing-related error: Please ensure you have enough credits to run this model. "), + (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."), + (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."), + (475, "Billing-related error: Please ensure you have enough credits to run this model. "), (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."), (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), - ], ) - def test_run_async_errors(status_code, error_message): base_url = config.MODELS_RUN_URL model_id = "model-id" execute_url = urljoin(base_url, f"execute/{model_id}") - + with requests_mock.Mocker() as mock: mock.post(execute_url, status_code=status_code) - test_model = Model(id=model_id, name="Test Model",url=base_url) + test_model = Model(id=model_id, name="Test Model", url=base_url) response = test_model.run_async(data="input_data") assert response["status"] == "FAILED" - assert response["error_message"] == error_message \ No newline at end of file + assert response["error_message"] == error_message
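
Usage note (not part of the patch): a minimal sketch of how the new Model fields introduced above surface to SDK users; the model id below is a placeholder and a configured TEAM_API_KEY is assumed.

# Hypothetical usage sketch, not part of the diff above.
from aixplain.factories import ModelFactory

model = ModelFactory.get("<model-id>")  # placeholder id

# created_at is parsed from the backend's ISO-8601 "createdAt" field and is timezone-aware.
print(model.created_at)

# input_params / output_params are dicts keyed by parameter code, taken from the function spec,
# e.g. {"text": {"name": "Text Prompt", ...}} and {"data": {"name": "Generated Image", ...}}.
print(model.input_params)
print(model.output_params)

# to_dict() now includes the I/O specs alongside id, name, supplier, and additional_info.
assert "input_params" in model.to_dict()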