diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py
index c11d837a..da44600c 100644
--- a/aixplain/factories/model_factory.py
+++ b/aixplain/factories/model_factory.py
@@ -30,6 +30,7 @@
 from aixplain.utils.file_utils import _request_with_retry
 from urllib.parse import urljoin
 from warnings import warn
+from aixplain.enums.function import FunctionInputOutput


 class ModelFactory:
@@ -66,6 +67,12 @@ def _create_model_from_response(cls, response: Dict) -> Model:
         if function == Function.TEXT_GENERATION:
             ModelClass = LLM

+        function_id = response["function"]["id"]
+        function = Function(function_id)
+        function_io = FunctionInputOutput.get(function_id, None)
+        input_params = {param["code"]: param for param in function_io["spec"]["params"]}
+        output_params = {param["code"]: param for param in function_io["spec"]["output"]}
+
         return ModelClass(
             response["id"],
             response["name"],
@@ -74,6 +81,8 @@ def _create_model_from_response(cls, response: Dict) -> Model:
             cost=response["pricing"],
             function=function,
             parameters=parameters,
+            input_params=input_params,
+            output_params=output_params,
             is_subscribed=True if "subscription" in response else False,
             version=response["version"]["id"],
         )
@@ -270,7 +279,7 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]:
         for dictionary in response_dicts:
             del dictionary["id"]
         return response_dicts
-    
+
     @classmethod
     def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]:
         """List GPU names on which you can host your language model.
@@ -335,7 +344,7 @@ def create_asset_repo(
         input_modality: Text,
         output_modality: Text,
         documentation_url: Optional[Text] = "",
-        api_key: Optional[Text] = None
+        api_key: Optional[Text] = None,
     ) -> Dict:
         """Creates an image repository for this model and registers it in the platform backend.

@@ -362,7 +371,7 @@ def create_asset_repo(
         function_id = function_dict["id"]
         if function_id is None:
             raise Exception(f"Invalid function name {function}")
-        create_url = urljoin(config.BACKEND_URL, f"sdk/models/onboard")
+        create_url = urljoin(config.BACKEND_URL, "sdk/models/onboard")
         logging.debug(f"URL: {create_url}")
         if api_key:
             headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"}
@@ -373,19 +382,14 @@ def create_asset_repo(
             "model": {
                 "name": name,
                 "description": description,
-                "connectionType": [
-                    "synchronous"
-                ],
+                "connectionType": ["synchronous"],
                 "function": function_id,
-                "modalities": [
-                    f"{input_modality}-{output_modality}"
-                ],
+                "modalities": [f"{input_modality}-{output_modality}"],
                 "documentationUrl": documentation_url,
-                "sourceLanguage": source_language
+                "sourceLanguage": source_language,
             },
             "source": "aixplain-ecr",
-            "onboardingParams": {
-            }
+            "onboardingParams": {},
         }
         logging.debug(f"Body: {str(payload)}")
         response = _request_with_retry("post", create_url, headers=headers, json=payload)
@@ -412,12 +416,18 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict:
         else:
             headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
         response = _request_with_retry("post", login_url, headers=headers)
-        print(f"Response: {response}")
         response_dict = json.loads(response.text)
         return response_dict

     @classmethod
-    def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_machine: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
+    def onboard_model(
+        cls,
+        model_id: Text,
+        image_tag: Text,
+        image_hash: Text,
+        host_machine: Optional[Text] = "",
+        api_key: Optional[Text] = None,
+    ) -> Dict:
         """Onboard a model after its image has been pushed to ECR.

         Args:
@@ -446,7 +456,14 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_m
         return response

     @classmethod
-    def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Optional[Text] = "", hf_token: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
+    def deploy_huggingface_model(
+        cls,
+        name: Text,
+        hf_repo_id: Text,
+        revision: Optional[Text] = "",
+        hf_token: Optional[Text] = "",
+        api_key: Optional[Text] = None,
+    ) -> Dict:
         """Onboards and deploys a Hugging Face large language model.

         Args:
@@ -477,8 +494,8 @@ def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Option
                 "hf_supplier": supplier,
                 "hf_model_name": model_name,
                 "hf_token": hf_token,
-                "revision": revision
-            }
+                "revision": revision,
+            },
         }
         response = _request_with_retry("post", deploy_url, headers=headers, json=body)
         logging.debug(response.text)
diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py
index 8fcd80d2..e18f1896 100644
--- a/aixplain/modules/model/__init__.py
+++ b/aixplain/modules/model/__init__.py
@@ -48,6 +48,8 @@ class Model(Asset):
         backend_url (str): URL of the backend.
         pricing (Dict, optional): model price. Defaults to None.
         **additional_info: Any additional Model info to be saved
+        input_params (Dict, optional): input parameters for the function.
+        output_params (Dict, optional): output parameters for the function.
     """

     def __init__(
@@ -61,6 +63,8 @@ def __init__(
         function: Optional[Function] = None,
         is_subscribed: bool = False,
         cost: Optional[Dict] = None,
+        input_params: Optional[Dict] = None,
+        output_params: Optional[Dict] = None,
         **additional_info,
     ) -> None:
         """Model Init
@@ -84,6 +88,8 @@ def __init__(
         self.backend_url = config.BACKEND_URL
         self.function = function
         self.is_subscribed = is_subscribed
+        self.input_params = input_params
+        self.output_params = output_params

     def to_dict(self) -> Dict:
         """Get the model info as a Dictionary
@@ -92,7 +98,7 @@ def to_dict(self) -> Dict:
             Dict: Model Information
         """
         clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None}
-        return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info}
+        return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info, "input_params": self.input_params, "output_params": self.output_params}

     def __repr__(self):
         try:
diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py
index d35a4d9a..b0d8f6ef 100644
--- a/tests/functional/general_assets/asset_functional_test.py
+++ b/tests/functional/general_assets/asset_functional_test.py
@@ -112,3 +112,25 @@ def test_llm_instantiation():
     """Test that the LLM model is correctly instantiated."""
     models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"]
     assert isinstance(models[0], LLM)
+
+
+def test_model_io():
+    model_id = "64aee5824d34b1221e70ac07"
+    model = ModelFactory.get(model_id)
+
+    expected_input = {
+        "text": {
+            "name": "Text Prompt",
+            "code": "text",
+            "required": True,
+            "isFixed": False,
+            "dataType": "text",
+            "dataSubType": "text",
+            "multipleValues": False,
+            "defaultValues": [],
+        }
+    }
+    expected_output = {"data": {"name": "Generated Image", "code": "data", "defaultValue": [], "dataType": "image"}}
+
+    assert model.input_params == expected_input
+    assert model.output_params == expected_output
diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py
index cd6f7a5a..c52bb950 100644
--- a/tests/unit/model_test.py
+++ b/tests/unit/model_test.py
@@ -63,24 +63,22 @@ def test_failed_poll():
 @pytest.mark.parametrize(
     "status_code,error_message",
     [
-        (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."),
-        (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."),
-        (475,"Billing-related error: Please ensure you have enough credits to run this model. "),
+        (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."),
+        (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."),
+        (475, "Billing-related error: Please ensure you have enough credits to run this model. "),
         (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."),
         (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."),
         (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."),
-
     ],
 )
-
 def test_run_async_errors(status_code, error_message):
     base_url = config.MODELS_RUN_URL
     model_id = "model-id"
     execute_url = urljoin(base_url, f"execute/{model_id}")
-    
+
     with requests_mock.Mocker() as mock:
         mock.post(execute_url, status_code=status_code)
-        test_model = Model(id=model_id, name="Test Model",url=base_url)
+        test_model = Model(id=model_id, name="Test Model", url=base_url)
         response = test_model.run_async(data="input_data")
         assert response["status"] == "FAILED"
-        assert response["error_message"] == error_message
\ No newline at end of file
+        assert response["error_message"] == error_message