Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions aixplain/factories/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from aixplain.utils.file_utils import _request_with_retry
from urllib.parse import urljoin
from warnings import warn
from aixplain.enums.function import FunctionInputOutput


class ModelFactory:
Expand Down Expand Up @@ -66,6 +67,12 @@ def _create_model_from_response(cls, response: Dict) -> Model:
if function == Function.TEXT_GENERATION:
ModelClass = LLM

function_id = response["function"]["id"]
function = Function(function_id)
function_io = FunctionInputOutput.get(function_id, None)
input_params = {param["code"]: param for param in function_io["spec"]["params"]}
output_params = {param["code"]: param for param in function_io["spec"]["output"]}

return ModelClass(
response["id"],
response["name"],
Expand All @@ -74,6 +81,8 @@ def _create_model_from_response(cls, response: Dict) -> Model:
cost=response["pricing"],
function=function,
parameters=parameters,
input_params=input_params,
output_params=output_params,
is_subscribed=True if "subscription" in response else False,
version=response["version"]["id"],
)
Expand Down Expand Up @@ -270,7 +279,7 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]:
for dictionary in response_dicts:
del dictionary["id"]
return response_dicts

@classmethod
def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]:
"""List GPU names on which you can host your language model.
Expand Down Expand Up @@ -335,7 +344,7 @@ def create_asset_repo(
input_modality: Text,
output_modality: Text,
documentation_url: Optional[Text] = "",
api_key: Optional[Text] = None
api_key: Optional[Text] = None,
) -> Dict:
"""Creates an image repository for this model and registers it in the
platform backend.
Expand All @@ -362,7 +371,7 @@ def create_asset_repo(
function_id = function_dict["id"]
if function_id is None:
raise Exception(f"Invalid function name {function}")
create_url = urljoin(config.BACKEND_URL, f"sdk/models/onboard")
create_url = urljoin(config.BACKEND_URL, "sdk/models/onboard")
logging.debug(f"URL: {create_url}")
if api_key:
headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"}
Expand All @@ -373,19 +382,14 @@ def create_asset_repo(
"model": {
"name": name,
"description": description,
"connectionType": [
"synchronous"
],
"connectionType": ["synchronous"],
"function": function_id,
"modalities": [
f"{input_modality}-{output_modality}"
],
"modalities": [f"{input_modality}-{output_modality}"],
"documentationUrl": documentation_url,
"sourceLanguage": source_language
"sourceLanguage": source_language,
},
"source": "aixplain-ecr",
"onboardingParams": {
}
"onboardingParams": {},
}
logging.debug(f"Body: {str(payload)}")
response = _request_with_retry("post", create_url, headers=headers, json=payload)
Expand All @@ -412,12 +416,18 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict:
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
response = _request_with_retry("post", login_url, headers=headers)
print(f"Response: {response}")
response_dict = json.loads(response.text)
return response_dict

@classmethod
def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_machine: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
def onboard_model(
cls,
model_id: Text,
image_tag: Text,
image_hash: Text,
host_machine: Optional[Text] = "",
api_key: Optional[Text] = None,
) -> Dict:
"""Onboard a model after its image has been pushed to ECR.

Args:
Expand Down Expand Up @@ -446,7 +456,14 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_m
return response

@classmethod
def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Optional[Text] = "", hf_token: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
def deploy_huggingface_model(
cls,
name: Text,
hf_repo_id: Text,
revision: Optional[Text] = "",
hf_token: Optional[Text] = "",
api_key: Optional[Text] = None,
) -> Dict:
"""Onboards and deploys a Hugging Face large language model.

Args:
Expand Down Expand Up @@ -477,8 +494,8 @@ def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Option
"hf_supplier": supplier,
"hf_model_name": model_name,
"hf_token": hf_token,
"revision": revision
}
"revision": revision,
},
}
response = _request_with_retry("post", deploy_url, headers=headers, json=body)
logging.debug(response.text)
Expand Down
8 changes: 7 additions & 1 deletion aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Model(Asset):
backend_url (str): URL of the backend.
pricing (Dict, optional): model price. Defaults to None.
**additional_info: Any additional Model info to be saved
input_params (Dict, optional): input parameters for the function.
output_params (Dict, optional): output parameters for the function.
"""

def __init__(
Expand All @@ -61,6 +63,8 @@ def __init__(
function: Optional[Function] = None,
is_subscribed: bool = False,
cost: Optional[Dict] = None,
input_params: Optional[Dict] = None,
output_params: Optional[Dict] = None,
**additional_info,
) -> None:
"""Model Init
Expand All @@ -84,6 +88,8 @@ def __init__(
self.backend_url = config.BACKEND_URL
self.function = function
self.is_subscribed = is_subscribed
self.input_params = input_params
self.output_params = output_params

def to_dict(self) -> Dict:
"""Get the model info as a Dictionary
Expand All @@ -92,7 +98,7 @@ def to_dict(self) -> Dict:
Dict: Model Information
"""
clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None}
return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info}
return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info, "input_params": self.input_params,"output_params": self.output_params,}

def __repr__(self):
try:
Expand Down
22 changes: 22 additions & 0 deletions tests/functional/general_assets/asset_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,25 @@ def test_llm_instantiation():
"""Test that the LLM model is correctly instantiated."""
models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"]
assert isinstance(models[0], LLM)


def test_model_io():
    """Check that ModelFactory.get fills in the model's input/output parameter specs.

    Uses a fixed text-to-image model id and compares the populated
    ``input_params`` / ``output_params`` dicts against the expected
    function-spec entries (keyed by each parameter's ``code``).
    """
    model = ModelFactory.get("64aee5824d34b1221e70ac07")

    # Expected spec for the single input parameter ("text").
    text_param = {
        "name": "Text Prompt",
        "code": "text",
        "required": True,
        "isFixed": False,
        "dataType": "text",
        "dataSubType": "text",
        "multipleValues": False,
        "defaultValues": [],
    }
    # Expected spec for the single output parameter ("data").
    data_param = {"name": "Generated Image", "code": "data", "defaultValue": [], "dataType": "image"}

    assert model.input_params == {"text": text_param}
    assert model.output_params == {"data": data_param}
14 changes: 6 additions & 8 deletions tests/unit/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,22 @@ def test_failed_poll():
@pytest.mark.parametrize(
"status_code,error_message",
[
(401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."),
(465,"Subscription-related error: Please ensure that your subscription is active and has not expired."),
(475,"Billing-related error: Please ensure you have enough credits to run this model. "),
(401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."),
(465, "Subscription-related error: Please ensure that your subscription is active and has not expired."),
(475, "Billing-related error: Please ensure you have enough credits to run this model. "),
(485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."),
(495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."),
(501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."),

],
)

def test_run_async_errors(status_code, error_message):
base_url = config.MODELS_RUN_URL
model_id = "model-id"
execute_url = urljoin(base_url, f"execute/{model_id}")

with requests_mock.Mocker() as mock:
mock.post(execute_url, status_code=status_code)
test_model = Model(id=model_id, name="Test Model",url=base_url)
test_model = Model(id=model_id, name="Test Model", url=base_url)
response = test_model.run_async(data="input_data")
assert response["status"] == "FAILED"
assert response["error_message"] == error_message
assert response["error_message"] == error_message