From 789d2dcb75e2b3d027a702d238982cf90c2b2ba8 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Mon, 25 Nov 2024 17:54:15 -0300 Subject: [PATCH 1/2] Onboard Utilities --- aixplain/enums/data_type.py | 1 + .../__init__.py} | 173 +++++------------ aixplain/factories/model_factory/utils.py | 142 ++++++++++++++ aixplain/modules/model/utility_model.py | 179 ++++++++++++++++++ tests/unit/model_test.py | 4 +- 5 files changed, 368 insertions(+), 131 deletions(-) rename aixplain/factories/{model_factory.py => model_factory/__init__.py} (73%) create mode 100644 aixplain/factories/model_factory/utils.py create mode 100644 aixplain/modules/model/utility_model.py diff --git a/aixplain/enums/data_type.py b/aixplain/enums/data_type.py index 11432bcf..dcae0422 100644 --- a/aixplain/enums/data_type.py +++ b/aixplain/enums/data_type.py @@ -35,6 +35,7 @@ class DataType(str, Enum): VIDEO = "video" EMBEDDING = "embedding" NUMBER = "number" + BOOLEAN = "boolean" def __str__(self): return self._value_ diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory/__init__.py similarity index 73% rename from aixplain/factories/model_factory.py rename to aixplain/factories/model_factory/__init__.py index b6588023..13db1fb4 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory/__init__.py @@ -24,14 +24,11 @@ import json import logging from aixplain.modules.model import Model -from aixplain.modules.model.llm_model import LLM +from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin -from warnings import warn -from aixplain.enums.function import FunctionInputOutput -from datetime import datetime class ModelFactory: @@ -45,53 +42,48 @@ class ModelFactory: backend_url = config.BACKEND_URL @classmethod - def _create_model_from_response(cls, response: Dict) -> Model: - """Converts response Json to 'Model' object + def create_utility_model(cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text) -> UtilityModel: + """Create a utility model Args: - response (Dict): Json from API + name (Text): name of the model + description (Text): description of the model + inputs (List[UtilityModelInput]): inputs of the model + code (Text): code of the model Returns: - Model: Coverted 'Model' object + UtilityModel: created utility model """ - if "api_key" not in response: - response["api_key"] = config.TEAM_API_KEY - - parameters = {} - if "params" in response: - for param in response["params"]: - if "language" in param["name"]: - parameters[param["name"]] = [w["value"] for w in param["values"]] - - function = Function(response["function"]["id"]) - ModelClass = Model - if function == Function.TEXT_GENERATION: - ModelClass = LLM - - created_at = None - if "createdAt" in response and response["createdAt"]: - created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) - function_id = response["function"]["id"] - function = Function(function_id) - function_io = FunctionInputOutput.get(function_id, None) - input_params = {param["code"]: param for param in function_io["spec"]["params"]} - output_params = {param["code"]: param for param in function_io["spec"]["output"]} - - return ModelClass( - response["id"], - response["name"], - description=response.get("description", ""), - supplier=response["supplier"], - api_key=response["api_key"], - cost=response["pricing"], - function=function, - created_at=created_at, - parameters=parameters, - input_params=input_params, - output_params=output_params, - is_subscribed=True if "subscription" in response else False, - version=response["version"]["id"], + utility_model = UtilityModel( + id="", + name=name, + description=description, + inputs=inputs, + code=code, + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, ) + payload = utility_model.to_dict() + url = urljoin(cls.backend_url, "sdk/utilities") + headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + try: + logging.info(f"Start service for POST Utility Model - {url} - {headers} - {payload}") + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() + except Exception as e: + logging.error(f"Error creating utility model: {e}") + raise e + + if 200 <= r.status_code < 300: + utility_model.id = resp["id"] + logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.") + return utility_model + else: + error_message = ( + f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) @classmethod def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: @@ -128,7 +120,9 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: resp["api_key"] = config.TEAM_API_KEY if api_key is not None: resp["api_key"] = api_key - model = cls._create_model_from_response(resp) + from aixplain.factories.model_factory.utils import create_model_from_response + + model = create_model_from_response(resp) logging.info(f"Model Creation: Model {model_id} instantiated.") return model else: @@ -136,89 +130,6 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: logging.error(error_message) raise Exception(error_message) - @classmethod - def create_asset_from_id(cls, model_id: Text) -> Model: - warn( - 'This method will be deprecated in the next versions of the SDK. Use "get" instead.', - DeprecationWarning, - stacklevel=2, - ) - return cls.get(model_id) - - @classmethod - def _get_assets_from_page( - cls, - query, - page_number: int, - page_size: int, - function: Function, - suppliers: Union[Supplier, List[Supplier]], - source_languages: Union[Language, List[Language]], - target_languages: Union[Language, List[Language]], - is_finetunable: bool = None, - ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, - sort_by: Optional[SortBy] = None, - sort_order: SortOrder = SortOrder.ASCENDING, - ) -> List[Model]: - try: - url = urljoin(cls.backend_url, "sdk/models/paginate") - filter_params = {"q": query, "pageNumber": page_number, "pageSize": page_size} - if is_finetunable is not None: - filter_params["isFineTunable"] = is_finetunable - if function is not None: - filter_params["functions"] = [function.value] - if suppliers is not None: - if isinstance(suppliers, Supplier) is True: - suppliers = [suppliers] - filter_params["suppliers"] = [supplier.value["id"] for supplier in suppliers] - if ownership is not None: - if isinstance(ownership, OwnershipType) is True: - ownership = [ownership] - filter_params["ownership"] = [ownership_.value for ownership_ in ownership] - - lang_filter_params = [] - if source_languages is not None: - if isinstance(source_languages, Language): - source_languages = [source_languages] - if function == Function.TRANSLATION: - lang_filter_params.append({"code": "sourcelanguage", "value": source_languages[0].value["language"]}) - else: - lang_filter_params.append({"code": "language", "value": source_languages[0].value["language"]}) - if source_languages[0].value["dialect"] != "": - lang_filter_params.append({"code": "dialect", "value": source_languages[0].value["dialect"]}) - if target_languages is not None: - if isinstance(target_languages, Language): - target_languages = [target_languages] - if function == Function.TRANSLATION: - code = "targetlanguage" - lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]}) - if sort_by is not None: - filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}] - if len(lang_filter_params) != 0: - filter_params["ioFilter"] = lang_filter_params - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - - logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") - r = _request_with_retry("post", url, headers=headers, json=filter_params) - resp = r.json() - - except Exception as e: - error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}" - logging.error(error_message, exc_info=True) - return [] - if 200 <= r.status_code < 300: - logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") - all_models = resp["items"] - model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] - return model_list, resp["total"] - else: - error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" - logging.error(error_message) - raise Exception(error_message) - @classmethod def list( cls, @@ -249,7 +160,9 @@ def list( Returns: List[Model]: List of models based on given filters """ - models, total = cls._get_assets_from_page( + from aixplain.factories.model_factory.utils import get_assets_from_page + + models, total = get_assets_from_page( query, page_number, page_size, diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py new file mode 100644 index 00000000..01423795 --- /dev/null +++ b/aixplain/factories/model_factory/utils.py @@ -0,0 +1,142 @@ +import json +import logging +from aixplain.modules.model import Model +from aixplain.modules.model.llm_model import LLM +from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput +from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder +from aixplain.utils import config +from aixplain.utils.file_utils import _request_with_retry +from aixplain.enums.function import FunctionInputOutput +from datetime import datetime +from typing import Dict, Union, List, Optional, Tuple +from urllib.parse import urljoin + + +def create_model_from_response(response: Dict) -> Model: + """Converts response Json to 'Model' object + + Args: + response (Dict): Json from API + + Returns: + Model: Coverted 'Model' object + """ + if "api_key" not in response: + response["api_key"] = config.TEAM_API_KEY + + parameters = {} + if "params" in response: + for param in response["params"]: + if "language" in param["name"]: + parameters[param["name"]] = [w["value"] for w in param["values"]] + + function = Function(response["function"]["id"]) + inputs = [] + ModelClass = Model + if function == Function.TEXT_GENERATION: + ModelClass = LLM + elif function == Function.UTILITIES: + ModelClass = UtilityModel + inputs = [ + UtilityModelInput(name=param["name"], description=param.get("description", ""), type=DataType(param["dataType"])) + for param in response["params"] + ] + + created_at = None + if "createdAt" in response and response["createdAt"]: + created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) + function_id = response["function"]["id"] + function = Function(function_id) + function_io = FunctionInputOutput.get(function_id, None) + input_params = {param["code"]: param for param in function_io["spec"]["params"]} + output_params = {param["code"]: param for param in function_io["spec"]["output"]} + + return ModelClass( + response["id"], + response["name"], + description=response.get("description", ""), + code=response.get("code", ""), + supplier=response["supplier"], + api_key=response["api_key"], + cost=response["pricing"], + function=function, + created_at=created_at, + parameters=parameters, + input_params=input_params, + output_params=output_params, + is_subscribed=True if "subscription" in response else False, + version=response["version"]["id"], + inputs=inputs, + ) + + +def get_assets_from_page( + query, + page_number: int, + page_size: int, + function: Function, + suppliers: Union[Supplier, List[Supplier]], + source_languages: Union[Language, List[Language]], + target_languages: Union[Language, List[Language]], + is_finetunable: bool = None, + ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, + sort_by: Optional[SortBy] = None, + sort_order: SortOrder = SortOrder.ASCENDING, +) -> List[Model]: + try: + url = urljoin(config.BACKEND_URL, "sdk/models/paginate") + filter_params = {"q": query, "pageNumber": page_number, "pageSize": page_size} + if is_finetunable is not None: + filter_params["isFineTunable"] = is_finetunable + if function is not None: + filter_params["functions"] = [function.value] + if suppliers is not None: + if isinstance(suppliers, Supplier) is True: + suppliers = [suppliers] + filter_params["suppliers"] = [supplier.value["id"] for supplier in suppliers] + if ownership is not None: + if isinstance(ownership, OwnershipType) is True: + ownership = [ownership] + filter_params["ownership"] = [ownership_.value for ownership_ in ownership] + + lang_filter_params = [] + if source_languages is not None: + if isinstance(source_languages, Language): + source_languages = [source_languages] + if function == Function.TRANSLATION: + lang_filter_params.append({"code": "sourcelanguage", "value": source_languages[0].value["language"]}) + else: + lang_filter_params.append({"code": "language", "value": source_languages[0].value["language"]}) + if source_languages[0].value["dialect"] != "": + lang_filter_params.append({"code": "dialect", "value": source_languages[0].value["dialect"]}) + if target_languages is not None: + if isinstance(target_languages, Language): + target_languages = [target_languages] + if function == Function.TRANSLATION: + code = "targetlanguage" + lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]}) + if sort_by is not None: + filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}] + if len(lang_filter_params) != 0: + filter_params["ioFilter"] = lang_filter_params + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") + r = _request_with_retry("post", url, headers=headers, json=filter_params) + resp = r.json() + + except Exception as e: + error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}" + logging.error(error_message, exc_info=True) + return [] + if 200 <= r.status_code < 300: + logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") + all_models = resp["items"] + from aixplain.factories.model_factory.utils import create_model_from_response + + model_list = [create_model_from_response(model_info_json) for model_info_json in all_models] + return model_list, resp["total"] + else: + error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py new file mode 100644 index 00000000..6835b56a --- /dev/null +++ b/aixplain/modules/model/utility_model.py @@ -0,0 +1,179 @@ +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 +Description: + Utility Model Class +""" +import logging +import os +import validators +from aixplain.enums import Function, Supplier, DataType +from aixplain.modules.model import Model +from aixplain.utils import config +from aixplain.utils.file_utils import _request_with_retry +from dataclasses import dataclass +from typing import Union, Optional, List, Text, Dict +from urllib.parse import urljoin + + +@dataclass +class UtilityModelInput: + name: Text + description: Text + type: DataType = DataType.TEXT + + def __post_init__(self): + self.validate_type() + + def validate_type(self): + if self.type not in [DataType.TEXT, DataType.BOOLEAN, DataType.NUMBER]: + raise ValueError("Utility Model Input type must be TEXT, BOOLEAN or NUMBER") + + def to_dict(self): + return {"name": self.name, "description": self.description, "type": self.type.value} + + +class UtilityModel(Model): + """Ready-to-use Utility Model. + + Attributes: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text, optional): description of the model. Defaults to "". + api_key (Text, optional): API key of the Model. Defaults to None. + url (Text, optional): endpoint of the model. Defaults to config.MODELS_RUN_URL. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to "1.0". + function (Text, optional): model AI function. Defaults to None. + url (str): URL to run the model. + backend_url (str): URL of the backend. + pricing (Dict, optional): model price. Defaults to None. + **additional_info: Any additional Model info to be saved + """ + + def __init__( + self, + id: Text, + name: Text, + description: Text, + code: Text, + inputs: List[UtilityModelInput], + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + **additional_info, + ) -> None: + """Utility Model Init + + Args: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text): description of the model. + code (Text): code of the model. + inputs (List[UtilityModelInput]): inputs of the model. + api_key (Text, optional): API key of the Model. Defaults to None. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to "1.0". + function (Function, optional): model AI function. Defaults to None. + is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. + **additional_info: Any additional Model info to be saved + """ + assert function == Function.UTILITIES, "Utility Model only supports 'utilities' function" + super().__init__( + id=id, + name=name, + description=description, + supplier=supplier, + version=version, + cost=cost, + function=function, + is_subscribed=is_subscribed, + api_key=api_key, + **additional_info, + ) + self.url = config.MODELS_RUN_URL + self.backend_url = config.BACKEND_URL + self.code = code + self.inputs = inputs + self.validate() + + def validate(self): + from aixplain.factories.file_factory import FileFactory + from uuid import uuid4 + + assert self.name and self.name.strip() != "", "Name is required" + assert self.description and self.description.strip() != "", "Description is required" + assert self.code and self.code.strip() != "", "Code is required" + assert self.inputs and len(self.inputs) > 0, "At least one input is required" + + self.code = FileFactory.to_link(self.code) + # store code in a temporary local path if it is not a valid URL or S3 path + if not validators.url(self.code) and not self.code.startswith("s3:"): + local_path = str(uuid4()) + with open(local_path, "w") as f: + f.write(self.code) + self.code = FileFactory.upload(local_path=local_path, is_temp=True) + os.remove(local_path) + + def to_dict(self): + return { + "name": self.name, + "description": self.description, + "inputs": [input.to_dict() for input in self.inputs], + "code": self.code, + "function": self.function.value, + } + + def update(self): + self.validate() + url = urljoin(self.backend_url, f"sdk/utilities/{self.id}") + headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"} + payload = self.to_dict() + try: + logging.info(f"Start service for PUT Utility Model - {url} - {headers} - {payload}") + r = _request_with_retry("put", url, headers=headers, json=payload) + response = r.json() + except Exception as e: + message = f"Utility Model Update Error: {e}" + logging.error(message) + raise Exception(f"{message}") + + if not 200 <= r.status_code < 300: + message = f"Utility Model Update Error: {response}" + logging.error(message) + raise Exception(f"{message}") + + def delete(self): + url = urljoin(self.backend_url, f"sdk/utilities/{self.id}") + headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"} + try: + logging.info(f"Start service for DELETE Utility Model - {url} - {headers}") + r = _request_with_retry("delete", url, headers=headers) + response = r.json() + except Exception: + message = "Utility Model Deletion Error: Make sure the utility model exists and you are the owner." + logging.error(message) + raise Exception(f"{message}") + + if r.status_code != 200: + message = f"Utility Model Deletion Error: {response}" + logging.error(message) + raise Exception(f"{message}") diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index a2463a8d..4431c135 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -164,6 +164,8 @@ def test_get_model_error_response(): def test_get_assets_from_page_error(): + from aixplain.factories.model_factory.utils import get_assets_from_page + with requests_mock.Mocker() as mock: query = "test-query" page_number = 0 @@ -175,7 +177,7 @@ def test_get_assets_from_page_error(): mock.post(url, headers=headers, json=error_response, status_code=500) with pytest.raises(Exception) as excinfo: - ModelFactory._get_assets_from_page( + get_assets_from_page( query=query, page_number=page_number, page_size=page_size, From 685db06a4f0e605857b140bc46e6673427fe9093 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Tue, 26 Nov 2024 15:50:48 -0300 Subject: [PATCH 2/2] Tests for utility models --- aixplain/factories/model_factory/__init__.py | 6 +- aixplain/modules/model/utility_model.py | 5 + .../model/run_utility_model_test.py | 33 +++++++ tests/unit/utility_test.py | 99 +++++++++++++++++++ 4 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 tests/functional/model/run_utility_model_test.py create mode 100644 tests/unit/utility_test.py diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 13db1fb4..75156426 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -42,7 +42,9 @@ class ModelFactory: backend_url = config.BACKEND_URL @classmethod - def create_utility_model(cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text) -> UtilityModel: + def create_utility_model( + cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text, output_description: Text + ) -> UtilityModel: """Create a utility model Args: @@ -50,6 +52,7 @@ def create_utility_model(cls, name: Text, description: Text, inputs: List[Utilit description (Text): description of the model inputs (List[UtilityModelInput]): inputs of the model code (Text): code of the model + output_description (Text): description of the output Returns: UtilityModel: created utility model @@ -62,6 +65,7 @@ def create_utility_model(cls, name: Text, description: Text, inputs: List[Utilit code=code, function=Function.UTILITIES, api_key=config.TEAM_API_KEY, + output_description=output_description, ) payload = utility_model.to_dict() url = urljoin(cls.backend_url, "sdk/utilities") diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 6835b56a..31bc6058 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -72,6 +72,7 @@ def __init__( description: Text, code: Text, inputs: List[UtilityModelInput], + output_description: Text, api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, @@ -88,6 +89,7 @@ def __init__( description (Text): description of the model. code (Text): code of the model. inputs (List[UtilityModelInput]): inputs of the model. + output_description (Text): description of the output api_key (Text, optional): API key of the Model. Defaults to None. supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". version (Text, optional): version of the model. Defaults to "1.0". @@ -113,6 +115,7 @@ def __init__( self.backend_url = config.BACKEND_URL self.code = code self.inputs = inputs + self.output_description = output_description self.validate() def validate(self): @@ -123,6 +126,7 @@ def validate(self): assert self.description and self.description.strip() != "", "Description is required" assert self.code and self.code.strip() != "", "Code is required" assert self.inputs and len(self.inputs) > 0, "At least one input is required" + assert self.output_description and self.output_description.strip() != "", "Output description is required" self.code = FileFactory.to_link(self.code) # store code in a temporary local path if it is not a valid URL or S3 path @@ -140,6 +144,7 @@ def to_dict(self): "inputs": [input.to_dict() for input in self.inputs], "code": self.code, "function": self.function.value, + "outputDescription": self.output_description, } def update(self): diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py new file mode 100644 index 00000000..5887c4ca --- /dev/null +++ b/tests/functional/model/run_utility_model_test.py @@ -0,0 +1,33 @@ +from aixplain.factories import ModelFactory +from aixplain.modules.model.utility_model import UtilityModelInput +from aixplain.enums import DataType + + +def test_run_utility_model(): + inputs = [ + UtilityModelInput(name="inputA", description="input A is the only input", type=DataType.TEXT), + ] + + output_description = "An example is 'test'" + + utility_model = ModelFactory.create_utility_model( + name="test_script", + description="This is a test script", + inputs=inputs, + code="def main(inputA):\n\treturn inputA", + output_description=output_description, + ) + + assert utility_model.id is not None + + response = utility_model.run(data={"inputA": "test"}) + assert response.status == "SUCCESS" + assert response.data == "test" + + utility_model.code = "def main(inputA):\n\treturn 5" + utility_model.update() + response = utility_model.run(data={"inputA": "test"}) + assert response.status == "SUCCESS" + assert str(response.data) == "5" + + utility_model.delete() diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py new file mode 100644 index 00000000..c1b7b9e1 --- /dev/null +++ b/tests/unit/utility_test.py @@ -0,0 +1,99 @@ +import pytest +import requests_mock +from aixplain.factories.model_factory import ModelFactory +from urllib.parse import urljoin +from aixplain.utils import config +from aixplain.enums import DataType, Function +from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput +from unittest.mock import patch + + +def test_utility_model(): + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + mock.post(urljoin(config.BACKEND_URL, "sdk/utilities"), json={"id": "123"}) + utility_model = ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + output_description="output_description", + ) + assert utility_model.id == "123" + assert utility_model.name == "utility_model_test" + assert utility_model.description == "utility_model_test" + assert utility_model.code == "utility_model_test" + assert utility_model.inputs == [ + UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT) + ] + assert utility_model.output_description == "output_description" + + +def test_utility_model_with_invalid_name(): + with pytest.raises(Exception) as exc_info: + ModelFactory.create_utility_model( + name="", + description="utility_model_test", + code="utility_model_test", + inputs=[], + output_description="output_description", + ) + assert str(exc_info.value) == "Name is required" + + +def test_utility_model_with_invalid_inputs(): + with pytest.raises(Exception) as exc_info: + ModelFactory.create_utility_model( + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + inputs=[], + output_description="output_description", + ) + assert str(exc_info.value) == "At least one input is required" + + +def test_utility_model_to_dict(): + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + utility_model = UtilityModel( + id="123", + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + output_description="output_description", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + assert utility_model.to_dict() == { + "name": "utility_model_test", + "description": "utility_model_test", + "inputs": [{"name": "originCode", "description": "originCode", "type": "text"}], + "code": "utility_model_test", + "function": "utilities", + "outputDescription": "output_description", + } + + +def test_update_utility_model(): + with requests_mock.Mocker() as mock: + with patch("aixplain.factories.file_factory.FileFactory.to_link", return_value="utility_model_test"): + with patch("aixplain.factories.file_factory.FileFactory.upload", return_value="utility_model_test"): + mock.put(urljoin(config.BACKEND_URL, "sdk/utilities/123"), json={"id": "123"}) + utility_model = UtilityModel( + id="123", + name="utility_model_test", + description="utility_model_test", + code="utility_model_test", + output_description="output_description", + inputs=[UtilityModelInput(name="originCode", description="originCode", type=DataType.TEXT)], + function=Function.UTILITIES, + api_key=config.TEAM_API_KEY, + ) + utility_model.description = "updated_description" + utility_model.update() + + assert utility_model.id == "123" + assert utility_model.description == "updated_description"