From c42afe57051c420345c0af8367f1bfec85960028 Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 30 Aug 2024 22:31:24 +0300 Subject: [PATCH 1/5] Added input and output attributes to model --- aixplain/factories/model_factory.py | 8 +++++++- aixplain/modules/model/__init__.py | 6 +++++- tests/unit/model_test.py | 13 ++++++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index c11d837a..04ea6d1f 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -30,7 +30,7 @@ from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin from warnings import warn - +from aixplain.enums.function import load_functions class ModelFactory: """A static class for creating and exploring Model Objects. @@ -66,6 +66,10 @@ def _create_model_from_response(cls, response: Dict) -> Model: if function == Function.TEXT_GENERATION: ModelClass = LLM + _, functionio= load_functions() + input_data= functionio[function]["input"] + output_data= functionio[function]["output"] + return ModelClass( response["id"], response["name"], @@ -74,6 +78,8 @@ def _create_model_from_response(cls, response: Dict) -> Model: cost=response["pricing"], function=function, parameters=parameters, + input_params=input_data, + output_params=output_data, is_subscribed=True if "subscription" in response else False, version=response["version"]["id"], ) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 8fcd80d2..4328bfd3 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -59,6 +59,8 @@ def __init__( supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, function: Optional[Function] = None, + input_params: Dict = {}, + output_params: Dict = {}, is_subscribed: bool = False, cost: Optional[Dict] = None, **additional_info, @@ -83,6 +85,8 @@ def __init__( self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.function = function + self.input_params = input_params + self.output_params = output_params self.is_subscribed = is_subscribed def to_dict(self) -> Dict: @@ -92,7 +96,7 @@ def to_dict(self) -> Dict: Dict: Model Information """ clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None} - return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info} + return {"id": self.id, "name": self.name, "supplier": self.supplier,"input_params": self.input_params,"output_params": self.output_params, "additional_info": clean_additional_info} def __repr__(self): try: diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index cd6f7a5a..f0037f53 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -24,6 +24,7 @@ import re from aixplain.utils import config from aixplain.modules import Model +from aixplain.factories import ModelFactory import pytest @@ -83,4 +84,14 @@ def test_run_async_errors(status_code, error_message): test_model = Model(id=model_id, name="Test Model",url=base_url) response = test_model.run_async(data="input_data") assert response["status"] == "FAILED" - assert response["error_message"] == error_message \ No newline at end of file + assert response["error_message"] == error_message + +def test_model_io(): + model_id = "64aee5824d34b1221e70ac07" + model = ModelFactory.get(model_id) + + expected_input = {"text"} + expected_output = {"image"} + + assert model.input_params == expected_input + assert model.output_params == expected_output \ No newline at end of file From 34edd7903198eaa4167a817146a69df34a74c3c6 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 2 Sep 2024 22:44:52 +0300 Subject: [PATCH 2/5] Added correct test --- tests/unit/model_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index f0037f53..298ac603 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -19,12 +19,12 @@ from dotenv import load_dotenv from urllib.parse import urljoin import requests_mock +from aixplain.factories import ModelFactory load_dotenv() import re from aixplain.utils import config from aixplain.modules import Model -from aixplain.factories import ModelFactory import pytest From 322c13d53b087d6a453f023be4abac237d5c4361 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 2 Sep 2024 22:49:20 +0300 Subject: [PATCH 3/5] Fixed model class --- aixplain/modules/model/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 4328bfd3..e18f1896 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -48,6 +48,8 @@ class Model(Asset): backend_url (str): URL of the backend. pricing (Dict, optional): model price. Defaults to None. **additional_info: Any additional Model info to be saved + input_params (Dict, optional): input parameters for the function. + output_params (Dict, optional): output parameters for the function. """ def __init__( @@ -59,10 +61,10 @@ def __init__( supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, function: Optional[Function] = None, - input_params: Dict = {}, - output_params: Dict = {}, is_subscribed: bool = False, cost: Optional[Dict] = None, + input_params: Optional[Dict] = None, + output_params: Optional[Dict] = None, **additional_info, ) -> None: """Model Init @@ -85,9 +87,9 @@ def __init__( self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.function = function - self.input_params = input_params - self.output_params = output_params self.is_subscribed = is_subscribed + self.input_params = input_params + self.output_params = output_params def to_dict(self) -> Dict: """Get the model info as a Dictionary @@ -96,7 +98,7 @@ def to_dict(self) -> Dict: Dict: Model Information """ clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None} - return {"id": self.id, "name": self.name, "supplier": self.supplier,"input_params": self.input_params,"output_params": self.output_params, "additional_info": clean_additional_info} + return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info, "input_params": self.input_params,"output_params": self.output_params,} def __repr__(self): try: From c00f869e549d645966c4d987a01a5ca93146a224 Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 2 Sep 2024 22:50:22 +0300 Subject: [PATCH 4/5] Fixed model factory --- aixplain/factories/model_factory.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index 04ea6d1f..4fc844aa 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -30,7 +30,7 @@ from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin from warnings import warn -from aixplain.enums.function import load_functions +from aixplain.enums.function import FunctionInputOutput class ModelFactory: """A static class for creating and exploring Model Objects. @@ -65,11 +65,13 @@ def _create_model_from_response(cls, response: Dict) -> Model: ModelClass = Model if function == Function.TEXT_GENERATION: ModelClass = LLM - - _, functionio= load_functions() - input_data= functionio[function]["input"] - output_data= functionio[function]["output"] + function_id = response["function"]["id"] + function = Function(function_id) + function_io = FunctionInputOutput.get(function_id, None) + input_params= function_io['input'] + output_params=function_io['output'] + return ModelClass( response["id"], response["name"], @@ -78,8 +80,8 @@ def _create_model_from_response(cls, response: Dict) -> Model: cost=response["pricing"], function=function, parameters=parameters, - input_params=input_data, - output_params=output_data, + input_params=input_params, + output_params=output_params, is_subscribed=True if "subscription" in response else False, version=response["version"]["id"], ) @@ -418,7 +420,6 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict: else: headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} response = _request_with_retry("post", login_url, headers=headers) - print(f"Response: {response}") response_dict = json.loads(response.text) return response_dict From 5f785832ed4ffab6a7dc63370a7f6c059bcb2e53 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Mon, 2 Sep 2024 18:03:28 -0300 Subject: [PATCH 5/5] Getting the parameters from right source and add functional test --- aixplain/factories/model_factory.py | 50 +++++++++++-------- .../general_assets/asset_functional_test.py | 22 ++++++++ tests/unit/model_test.py | 23 ++------- 3 files changed, 57 insertions(+), 38 deletions(-) diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index 4fc844aa..da44600c 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -32,6 +32,7 @@ from warnings import warn from aixplain.enums.function import FunctionInputOutput + class ModelFactory: """A static class for creating and exploring Model Objects. @@ -65,12 +66,12 @@ def _create_model_from_response(cls, response: Dict) -> Model: ModelClass = Model if function == Function.TEXT_GENERATION: ModelClass = LLM - + function_id = response["function"]["id"] function = Function(function_id) function_io = FunctionInputOutput.get(function_id, None) - input_params= function_io['input'] - output_params=function_io['output'] + input_params = {param["code"]: param for param in function_io["spec"]["params"]} + output_params = {param["code"]: param for param in function_io["spec"]["output"]} return ModelClass( response["id"], @@ -80,7 +81,7 @@ def _create_model_from_response(cls, response: Dict) -> Model: cost=response["pricing"], function=function, parameters=parameters, - input_params=input_params, + input_params=input_params, output_params=output_params, is_subscribed=True if "subscription" in response else False, version=response["version"]["id"], @@ -278,7 +279,7 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]: for dictionary in response_dicts: del dictionary["id"] return response_dicts - + @classmethod def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]: """List GPU names on which you can host your language model. @@ -343,7 +344,7 @@ def create_asset_repo( input_modality: Text, output_modality: Text, documentation_url: Optional[Text] = "", - api_key: Optional[Text] = None + api_key: Optional[Text] = None, ) -> Dict: """Creates an image repository for this model and registers it in the platform backend. @@ -370,7 +371,7 @@ def create_asset_repo( function_id = function_dict["id"] if function_id is None: raise Exception(f"Invalid function name {function}") - create_url = urljoin(config.BACKEND_URL, f"sdk/models/onboard") + create_url = urljoin(config.BACKEND_URL, "sdk/models/onboard") logging.debug(f"URL: {create_url}") if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} @@ -381,19 +382,14 @@ def create_asset_repo( "model": { "name": name, "description": description, - "connectionType": [ - "synchronous" - ], + "connectionType": ["synchronous"], "function": function_id, - "modalities": [ - f"{input_modality}-{output_modality}" - ], + "modalities": [f"{input_modality}-{output_modality}"], "documentationUrl": documentation_url, - "sourceLanguage": source_language + "sourceLanguage": source_language, }, "source": "aixplain-ecr", - "onboardingParams": { - } + "onboardingParams": {}, } logging.debug(f"Body: {str(payload)}") response = _request_with_retry("post", create_url, headers=headers, json=payload) @@ -424,7 +420,14 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict: return response_dict @classmethod - def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_machine: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict: + def onboard_model( + cls, + model_id: Text, + image_tag: Text, + image_hash: Text, + host_machine: Optional[Text] = "", + api_key: Optional[Text] = None, + ) -> Dict: """Onboard a model after its image has been pushed to ECR. Args: @@ -453,7 +456,14 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_m return response @classmethod - def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Optional[Text] = "", hf_token: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict: + def deploy_huggingface_model( + cls, + name: Text, + hf_repo_id: Text, + revision: Optional[Text] = "", + hf_token: Optional[Text] = "", + api_key: Optional[Text] = None, + ) -> Dict: """Onboards and deploys a Hugging Face large language model. Args: @@ -484,8 +494,8 @@ def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Option "hf_supplier": supplier, "hf_model_name": model_name, "hf_token": hf_token, - "revision": revision - } + "revision": revision, + }, } response = _request_with_retry("post", deploy_url, headers=headers, json=body) logging.debug(response.text) diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py index d35a4d9a..b0d8f6ef 100644 --- a/tests/functional/general_assets/asset_functional_test.py +++ b/tests/functional/general_assets/asset_functional_test.py @@ -112,3 +112,25 @@ def test_llm_instantiation(): """Test that the LLM model is correctly instantiated.""" models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"] assert isinstance(models[0], LLM) + + +def test_model_io(): + model_id = "64aee5824d34b1221e70ac07" + model = ModelFactory.get(model_id) + + expected_input = { + "text": { + "name": "Text Prompt", + "code": "text", + "required": True, + "isFixed": False, + "dataType": "text", + "dataSubType": "text", + "multipleValues": False, + "defaultValues": [], + } + } + expected_output = {"data": {"name": "Generated Image", "code": "data", "defaultValue": [], "dataType": "image"}} + + assert model.input_params == expected_input + assert model.output_params == expected_output diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 298ac603..c52bb950 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -19,7 +19,6 @@ from dotenv import load_dotenv from urllib.parse import urljoin import requests_mock -from aixplain.factories import ModelFactory load_dotenv() import re @@ -64,34 +63,22 @@ def test_failed_poll(): @pytest.mark.parametrize( "status_code,error_message", [ - (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."), - (475,"Billing-related error: Please ensure you have enough credits to run this model. "), + (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."), + (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."), + (475, "Billing-related error: Please ensure you have enough credits to run this model. "), (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."), (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), - ], ) - def test_run_async_errors(status_code, error_message): base_url = config.MODELS_RUN_URL model_id = "model-id" execute_url = urljoin(base_url, f"execute/{model_id}") - + with requests_mock.Mocker() as mock: mock.post(execute_url, status_code=status_code) - test_model = Model(id=model_id, name="Test Model",url=base_url) + test_model = Model(id=model_id, name="Test Model", url=base_url) response = test_model.run_async(data="input_data") assert response["status"] == "FAILED" assert response["error_message"] == error_message - -def test_model_io(): - model_id = "64aee5824d34b1221e70ac07" - model = ModelFactory.get(model_id) - - expected_input = {"text"} - expected_output = {"image"} - - assert model.input_params == expected_input - assert model.output_params == expected_output \ No newline at end of file