diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index cd7de970..9ed3138f 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -65,7 +65,7 @@ def _create_model_from_response(cls, response: Dict) -> Model: response["name"], supplier=response["supplier"], api_key=response["api_key"], - pricing=response["pricing"], + cost=response["pricing"], function=Function(response["function"]["id"]), parameters=parameters, is_subscribed=True if "subscription" in response else False, @@ -404,9 +404,11 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, api_ke message = "Your onboarding request has been submitted to an aiXplain specialist for finalization. We will notify you when the process is completed." logging.info(message) return response - + @classmethod - def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, hf_token: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict: + def deploy_huggingface_model( + cls, name: Text, hf_repo_id: Text, hf_token: Optional[Text] = "", api_key: Optional[Text] = None + ) -> Dict: """Onboards and deploys a Hugging Face large language model. Args: @@ -433,20 +435,16 @@ def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, hf_token: Option "sourceLanguage": "en", }, "source": "huggingface", - "onboardingParams": { - "hf_model_name": model_name, - "hf_supplier": supplier, - "hf_token": hf_token - } + "onboardingParams": {"hf_model_name": model_name, "hf_supplier": supplier, "hf_token": hf_token}, } response = _request_with_retry("post", deploy_url, headers=headers, json=body) logging.debug(response.text) response_dicts = json.loads(response.text) return response_dicts - + @classmethod def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = None): - """Gets the on-boarding status of a Hugging Face model with ID MODEL_ID. + """Gets the on-boarding status of a Hugging Face model with ID MODEL_ID. Args: model_id (Text): The model's ID as returned by DEPLOY_HUGGINGFACE_MODEL @@ -466,6 +464,6 @@ def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = "status": response_dicts["status"], "name": response_dicts["name"], "id": response_dicts["id"], - "pricing": response_dicts["pricing"] + "pricing": response_dicts["pricing"], } - return ret_dict \ No newline at end of file + return ret_dict diff --git a/aixplain/modules/asset.py b/aixplain/modules/asset.py index 34fea4e4..52b79912 100644 --- a/aixplain/modules/asset.py +++ b/aixplain/modules/asset.py @@ -36,7 +36,7 @@ def __init__( version: Text = "1.0", license: Optional[License] = None, privacy: Privacy = Privacy.PRIVATE, - cost: float = 0, + cost: Optional[Union[Dict, float]] = None, ) -> None: """Create an Asset with the necessary information @@ -46,6 +46,7 @@ def __init__( description (Text): Description of the Asset supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". version (Optional[Text], optional): asset version. Defaults to "1.0". + cost (Optional[Union[Dict, float]], optional): asset price. Defaults to None. """ self.id = id self.name = name diff --git a/aixplain/modules/metric.py b/aixplain/modules/metric.py index 04a0bdd7..d591772b 100644 --- a/aixplain/modules/metric.py +++ b/aixplain/modules/metric.py @@ -61,12 +61,10 @@ def __init__( supplier (Text): author of the Metric is_reference_required (bool): does the metric use reference is_source_required (bool): does the metric use source - cost (float): cost of the metric + cost (float): price of the metric normalization_options(list, []) **additional_info: Any additional Metric info to be saved """ - - super().__init__(id, name, description="", supplier=supplier, version="1.0", cost=cost) self.is_source_required = is_source_required self.is_reference_required = is_reference_required @@ -76,7 +74,7 @@ def __init__( def __repr__(self) -> str: return f"" - + def add_normalization_options(self, normalization_options: List[str]): """Add a given set of normalization options to be used while benchmarking @@ -85,7 +83,12 @@ def add_normalization_options(self, normalization_options: List[str]): """ self.normalization_options.append(normalization_options) - def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional[Union[str, List[str]]]=None, reference: Optional[Union[str, List[str]]]=None): + def run( + self, + hypothesis: Optional[Union[str, List[str]]] = None, + source: Optional[Union[str, List[str]]] = None, + reference: Optional[Union[str, List[str]]] = None, + ): """Run the metric to calculate the scores. Args: @@ -94,6 +97,7 @@ def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional reference (Optional[Union[str, List[str]]], optional): Can give a single reference or a list of references for metric calculation. Defaults to None. """ from aixplain.factories.model_factory import ModelFactory + model = ModelFactory.get(self.id) payload = { "function": self.function, @@ -115,4 +119,3 @@ def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional reference = [[ref] for ref in reference] payload["references"] = reference return model.run(payload) - diff --git a/aixplain/modules/model.py b/aixplain/modules/model.py index fc3a82cd..983737c7 100644 --- a/aixplain/modules/model.py +++ b/aixplain/modules/model.py @@ -48,6 +48,7 @@ class Model(Asset): function (Text, optional): model AI function. Defaults to None. url (str): URL to run the model. backend_url (str): URL of the backend. + pricing (Dict, optional): model price. Defaults to None. **additional_info: Any additional Model info to be saved """ @@ -61,6 +62,7 @@ def __init__( version: Optional[Text] = None, function: Optional[Text] = None, is_subscribed: bool = False, + cost: Optional[Dict] = None, **additional_info, ) -> None: """Model Init @@ -74,9 +76,10 @@ def __init__( version (Text, optional): version of the model. Defaults to "1.0". function (Text, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. **additional_info: Any additional Model info to be saved """ - super().__init__(id, name, description, supplier, version) + super().__init__(id, name, description, supplier, version, cost=cost) self.api_key = api_key self.additional_info = additional_info self.url = config.MODELS_RUN_URL @@ -264,6 +267,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): """ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.finetune.status import FinetuneStatus + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} resp = None try: @@ -274,7 +278,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): finetune_status = AssetStatus(resp["finetuneStatus"]) model_status = AssetStatus(resp["modelStatus"]) logs = sorted(resp["logs"], key=lambda x: float(x["epoch"])) - + target_epoch = None if after_epoch is not None: logs = [log for log in logs if float(log["epoch"]) > after_epoch] @@ -282,7 +286,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): target_epoch = float(logs[0]["epoch"]) elif len(logs) > 0: target_epoch = float(logs[-1]["epoch"]) - + if target_epoch is not None: log = None for log_ in logs: @@ -294,7 +298,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): log["trainLoss"] = log_["trainLoss"] if log_["evalLoss"] is not None: log["evalLoss"] = log_["evalLoss"] - + status = FinetuneStatus( status=finetune_status, model_status=model_status, diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py index 6a9dceda..93a3b297 100644 --- a/tests/functional/general_assets/asset_functional_test.py +++ b/tests/functional/general_assets/asset_functional_test.py @@ -82,8 +82,8 @@ def test_model_sort(): prev_model = models[idx - 1] model = models[idx] - prev_model_price = prev_model.additional_info["pricing"]["price"] - model_price = model.additional_info["pricing"]["price"] + prev_model_price = prev_model.cost["price"] + model_price = model.cost["price"] assert prev_model_price >= model_price