Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions aixplain/factories/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _create_model_from_response(cls, response: Dict) -> Model:
response["name"],
supplier=response["supplier"],
api_key=response["api_key"],
pricing=response["pricing"],
cost=response["pricing"],
function=Function(response["function"]["id"]),
parameters=parameters,
is_subscribed=True if "subscription" in response else False,
Expand Down Expand Up @@ -404,9 +404,11 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, api_ke
message = "Your onboarding request has been submitted to an aiXplain specialist for finalization. We will notify you when the process is completed."
logging.info(message)
return response

@classmethod
def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, hf_token: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
def deploy_huggingface_model(
cls, name: Text, hf_repo_id: Text, hf_token: Optional[Text] = "", api_key: Optional[Text] = None
) -> Dict:
"""Onboards and deploys a Hugging Face large language model.

Args:
Expand All @@ -433,20 +435,16 @@ def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, hf_token: Option
"sourceLanguage": "en",
},
"source": "huggingface",
"onboardingParams": {
"hf_model_name": model_name,
"hf_supplier": supplier,
"hf_token": hf_token
}
"onboardingParams": {"hf_model_name": model_name, "hf_supplier": supplier, "hf_token": hf_token},
}
response = _request_with_retry("post", deploy_url, headers=headers, json=body)
logging.debug(response.text)
response_dicts = json.loads(response.text)
return response_dicts

@classmethod
def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = None):
"""Gets the on-boarding status of a Hugging Face model with ID MODEL_ID.
"""Gets the on-boarding status of a Hugging Face model with ID MODEL_ID.

Args:
model_id (Text): The model's ID as returned by DEPLOY_HUGGINGFACE_MODEL
Expand All @@ -466,6 +464,6 @@ def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] =
"status": response_dicts["status"],
"name": response_dicts["name"],
"id": response_dicts["id"],
"pricing": response_dicts["pricing"]
"pricing": response_dicts["pricing"],
}
return ret_dict
return ret_dict
3 changes: 2 additions & 1 deletion aixplain/modules/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
version: Text = "1.0",
license: Optional[License] = None,
privacy: Privacy = Privacy.PRIVATE,
cost: float = 0,
cost: Optional[Union[Dict, float]] = None,
) -> None:
"""Create an Asset with the necessary information

Expand All @@ -46,6 +46,7 @@ def __init__(
description (Text): Description of the Asset
supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain".
version (Optional[Text], optional): asset version. Defaults to "1.0".
cost (Optional[Union[Dict, float]], optional): asset price. Defaults to None.
"""
self.id = id
self.name = name
Expand Down
15 changes: 9 additions & 6 deletions aixplain/modules/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,10 @@ def __init__(
supplier (Text): author of the Metric
is_reference_required (bool): does the metric use reference
is_source_required (bool): does the metric use source
cost (float): cost of the metric
cost (float): price of the metric
normalization_options(list, [])
**additional_info: Any additional Metric info to be saved
"""


super().__init__(id, name, description="", supplier=supplier, version="1.0", cost=cost)
self.is_source_required = is_source_required
self.is_reference_required = is_reference_required
Expand All @@ -76,7 +74,7 @@ def __init__(

def __repr__(self) -> str:
return f"<Metric {self.name}>"

def add_normalization_options(self, normalization_options: List[str]):
"""Add a given set of normalization options to be used while benchmarking

Expand All @@ -85,7 +83,12 @@ def add_normalization_options(self, normalization_options: List[str]):
"""
self.normalization_options.append(normalization_options)

def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional[Union[str, List[str]]]=None, reference: Optional[Union[str, List[str]]]=None):
def run(
self,
hypothesis: Optional[Union[str, List[str]]] = None,
source: Optional[Union[str, List[str]]] = None,
reference: Optional[Union[str, List[str]]] = None,
):
"""Run the metric to calculate the scores.

Args:
Expand All @@ -94,6 +97,7 @@ def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional
reference (Optional[Union[str, List[str]]], optional): Can give a single reference or a list of references for metric calculation. Defaults to None.
"""
from aixplain.factories.model_factory import ModelFactory

model = ModelFactory.get(self.id)
payload = {
"function": self.function,
Expand All @@ -115,4 +119,3 @@ def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional
reference = [[ref] for ref in reference]
payload["references"] = reference
return model.run(payload)

12 changes: 8 additions & 4 deletions aixplain/modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Model(Asset):
function (Text, optional): model AI function. Defaults to None.
url (str): URL to run the model.
backend_url (str): URL of the backend.
pricing (Dict, optional): model price. Defaults to None.
**additional_info: Any additional Model info to be saved
"""

Expand All @@ -61,6 +62,7 @@ def __init__(
version: Optional[Text] = None,
function: Optional[Text] = None,
is_subscribed: bool = False,
cost: Optional[Dict] = None,
**additional_info,
) -> None:
"""Model Init
Expand All @@ -74,9 +76,10 @@ def __init__(
version (Text, optional): version of the model. Defaults to "1.0".
function (Text, optional): model AI function. Defaults to None.
is_subscribed (bool, optional): Is the user subscribed. Defaults to False.
cost (Dict, optional): model price. Defaults to None.
**additional_info: Any additional Model info to be saved
"""
super().__init__(id, name, description, supplier, version)
super().__init__(id, name, description, supplier, version, cost=cost)
self.api_key = api_key
self.additional_info = additional_info
self.url = config.MODELS_RUN_URL
Expand Down Expand Up @@ -264,6 +267,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None):
"""
from aixplain.enums.asset_status import AssetStatus
from aixplain.modules.finetune.status import FinetuneStatus

headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
resp = None
try:
Expand All @@ -274,15 +278,15 @@ def check_finetune_status(self, after_epoch: Optional[int] = None):
finetune_status = AssetStatus(resp["finetuneStatus"])
model_status = AssetStatus(resp["modelStatus"])
logs = sorted(resp["logs"], key=lambda x: float(x["epoch"]))

target_epoch = None
if after_epoch is not None:
logs = [log for log in logs if float(log["epoch"]) > after_epoch]
if len(logs) > 0:
target_epoch = float(logs[0]["epoch"])
elif len(logs) > 0:
target_epoch = float(logs[-1]["epoch"])

if target_epoch is not None:
log = None
for log_ in logs:
Expand All @@ -294,7 +298,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None):
log["trainLoss"] = log_["trainLoss"]
if log_["evalLoss"] is not None:
log["evalLoss"] = log_["evalLoss"]

status = FinetuneStatus(
status=finetune_status,
model_status=model_status,
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/general_assets/asset_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def test_model_sort():
prev_model = models[idx - 1]
model = models[idx]

prev_model_price = prev_model.additional_info["pricing"]["price"]
model_price = model.additional_info["pricing"]["price"]
prev_model_price = prev_model.cost["price"]
model_price = model.cost["price"]
assert prev_model_price >= model_price


Expand Down