Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aixplain/enums/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DataType(str, Enum):
VIDEO = "video"
EMBEDDING = "embedding"
NUMBER = "number"
BOOLEAN = "boolean"

def __str__(self):
return self._value_
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,11 @@
import json
import logging
from aixplain.modules.model import Model
from aixplain.modules.model.llm_model import LLM
from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput
from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder
from aixplain.utils import config
from aixplain.utils.file_utils import _request_with_retry
from urllib.parse import urljoin
from warnings import warn
from aixplain.enums.function import FunctionInputOutput
from datetime import datetime


class ModelFactory:
Expand All @@ -44,53 +41,52 @@ class ModelFactory:
backend_url = config.BACKEND_URL

@classmethod
def _create_model_from_response(cls, response: Dict) -> Model:
"""Converts response Json to 'Model' object
def create_utility_model(
cls, name: Text, description: Text, inputs: List[UtilityModelInput], code: Text, output_description: Text
) -> UtilityModel:
"""Create a utility model

Args:
response (Dict): Json from API
name (Text): name of the model
description (Text): description of the model
inputs (List[UtilityModelInput]): inputs of the model
code (Text): code of the model
output_description (Text): description of the output

Returns:
Model: Coverted 'Model' object
UtilityModel: created utility model
"""
if "api_key" not in response:
response["api_key"] = config.TEAM_API_KEY

parameters = {}
if "params" in response:
for param in response["params"]:
if "language" in param["name"]:
parameters[param["name"]] = [w["value"] for w in param["values"]]

function = Function(response["function"]["id"])
ModelClass = Model
if function == Function.TEXT_GENERATION:
ModelClass = LLM

created_at = None
if "createdAt" in response and response["createdAt"]:
created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00"))
function_id = response["function"]["id"]
function = Function(function_id)
function_io = FunctionInputOutput.get(function_id, None)
input_params = {param["code"]: param for param in function_io["spec"]["params"]}
output_params = {param["code"]: param for param in function_io["spec"]["output"]}

return ModelClass(
response["id"],
response["name"],
description=response.get("description", ""),
supplier=response["supplier"],
api_key=response["api_key"],
cost=response["pricing"],
function=function,
created_at=created_at,
parameters=parameters,
input_params=input_params,
output_params=output_params,
is_subscribed=True if "subscription" in response else False,
version=response["version"]["id"],
utility_model = UtilityModel(
id="",
name=name,
description=description,
inputs=inputs,
code=code,
function=Function.UTILITIES,
api_key=config.TEAM_API_KEY,
output_description=output_description,
)
payload = utility_model.to_dict()
url = urljoin(cls.backend_url, "sdk/utilities")
headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"}
try:
logging.info(f"Start service for POST Utility Model - {url} - {headers} - {payload}")
r = _request_with_retry("post", url, headers=headers, json=payload)
resp = r.json()
except Exception as e:
logging.error(f"Error creating utility model: {e}")
raise e

if 200 <= r.status_code < 300:
utility_model.id = resp["id"]
logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.")
return utility_model
else:
error_message = (
f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}"
)
logging.error(error_message)
raise Exception(error_message)

@classmethod
def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model:
Expand Down Expand Up @@ -125,95 +121,16 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model:
resp["api_key"] = config.TEAM_API_KEY
if api_key is not None:
resp["api_key"] = api_key
model = cls._create_model_from_response(resp)
from aixplain.factories.model_factory.utils import create_model_from_response

model = create_model_from_response(resp)
logging.info(f"Model Creation: Model {model_id} instantiated.")
return model
else:
error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def create_asset_from_id(cls, model_id: Text) -> Model:
warn(
'This method will be deprecated in the next versions of the SDK. Use "get" instead.',
DeprecationWarning,
stacklevel=2,
)
return cls.get(model_id)

@classmethod
def _get_assets_from_page(
cls,
query,
page_number: int,
page_size: int,
function: Function,
suppliers: Union[Supplier, List[Supplier]],
source_languages: Union[Language, List[Language]],
target_languages: Union[Language, List[Language]],
is_finetunable: bool = None,
ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None,
sort_by: Optional[SortBy] = None,
sort_order: SortOrder = SortOrder.ASCENDING,
) -> List[Model]:
try:
url = urljoin(cls.backend_url, "sdk/models/paginate")
filter_params = {"q": query, "pageNumber": page_number, "pageSize": page_size}
if is_finetunable is not None:
filter_params["isFineTunable"] = is_finetunable
if function is not None:
filter_params["functions"] = [function.value]
if suppliers is not None:
if isinstance(suppliers, Supplier) is True:
suppliers = [suppliers]
filter_params["suppliers"] = [supplier.value["id"] for supplier in suppliers]
if ownership is not None:
if isinstance(ownership, OwnershipType) is True:
ownership = [ownership]
filter_params["ownership"] = [ownership_.value for ownership_ in ownership]

lang_filter_params = []
if source_languages is not None:
if isinstance(source_languages, Language):
source_languages = [source_languages]
if function == Function.TRANSLATION:
lang_filter_params.append({"code": "sourcelanguage", "value": source_languages[0].value["language"]})
else:
lang_filter_params.append({"code": "language", "value": source_languages[0].value["language"]})
if source_languages[0].value["dialect"] != "":
lang_filter_params.append({"code": "dialect", "value": source_languages[0].value["dialect"]})
if target_languages is not None:
if isinstance(target_languages, Language):
target_languages = [target_languages]
if function == Function.TRANSLATION:
code = "targetlanguage"
lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]})
if sort_by is not None:
filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}]
if len(lang_filter_params) != 0:
filter_params["ioFilter"] = lang_filter_params

headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}

logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}")
r = _request_with_retry("post", url, headers=headers, json=filter_params)
resp = r.json()

except Exception as e:
error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}"
logging.error(error_message, exc_info=True)
return []
if 200 <= r.status_code < 300:
logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}")
all_models = resp["items"]
model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models]
return model_list, resp["total"]
else:
error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def list(
cls,
Expand Down Expand Up @@ -244,7 +161,9 @@ def list(
Returns:
List[Model]: List of models based on given filters
"""
models, total = cls._get_assets_from_page(
from aixplain.factories.model_factory.utils import get_assets_from_page

models, total = get_assets_from_page(
query,
page_number,
page_size,
Expand Down
142 changes: 142 additions & 0 deletions aixplain/factories/model_factory/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import json
import logging
from aixplain.modules.model import Model
from aixplain.modules.model.llm_model import LLM
from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput
from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder
from aixplain.utils import config
from aixplain.utils.file_utils import _request_with_retry
from aixplain.enums.function import FunctionInputOutput
from datetime import datetime
from typing import Dict, Union, List, Optional, Tuple
from urllib.parse import urljoin


def create_model_from_response(response: Dict) -> Model:
"""Converts response Json to 'Model' object

Args:
response (Dict): Json from API

Returns:
Model: Coverted 'Model' object
"""
if "api_key" not in response:
response["api_key"] = config.TEAM_API_KEY

parameters = {}
if "params" in response:
for param in response["params"]:
if "language" in param["name"]:
parameters[param["name"]] = [w["value"] for w in param["values"]]

function = Function(response["function"]["id"])
inputs = []
ModelClass = Model
if function == Function.TEXT_GENERATION:
ModelClass = LLM
elif function == Function.UTILITIES:
ModelClass = UtilityModel
inputs = [
UtilityModelInput(name=param["name"], description=param.get("description", ""), type=DataType(param["dataType"]))
for param in response["params"]
]

created_at = None
if "createdAt" in response and response["createdAt"]:
created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00"))
function_id = response["function"]["id"]
function = Function(function_id)
function_io = FunctionInputOutput.get(function_id, None)
input_params = {param["code"]: param for param in function_io["spec"]["params"]}
output_params = {param["code"]: param for param in function_io["spec"]["output"]}

return ModelClass(
response["id"],
response["name"],
description=response.get("description", ""),
code=response.get("code", ""),
supplier=response["supplier"],
api_key=response["api_key"],
cost=response["pricing"],
function=function,
created_at=created_at,
parameters=parameters,
input_params=input_params,
output_params=output_params,
is_subscribed=True if "subscription" in response else False,
version=response["version"]["id"],
inputs=inputs,
)


def get_assets_from_page(
query,
page_number: int,
page_size: int,
function: Function,
suppliers: Union[Supplier, List[Supplier]],
source_languages: Union[Language, List[Language]],
target_languages: Union[Language, List[Language]],
is_finetunable: bool = None,
ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None,
sort_by: Optional[SortBy] = None,
sort_order: SortOrder = SortOrder.ASCENDING,
) -> List[Model]:
try:
url = urljoin(config.BACKEND_URL, "sdk/models/paginate")
filter_params = {"q": query, "pageNumber": page_number, "pageSize": page_size}
if is_finetunable is not None:
filter_params["isFineTunable"] = is_finetunable
if function is not None:
filter_params["functions"] = [function.value]
if suppliers is not None:
if isinstance(suppliers, Supplier) is True:
suppliers = [suppliers]
filter_params["suppliers"] = [supplier.value["id"] for supplier in suppliers]
if ownership is not None:
if isinstance(ownership, OwnershipType) is True:
ownership = [ownership]
filter_params["ownership"] = [ownership_.value for ownership_ in ownership]

lang_filter_params = []
if source_languages is not None:
if isinstance(source_languages, Language):
source_languages = [source_languages]
if function == Function.TRANSLATION:
lang_filter_params.append({"code": "sourcelanguage", "value": source_languages[0].value["language"]})
else:
lang_filter_params.append({"code": "language", "value": source_languages[0].value["language"]})
if source_languages[0].value["dialect"] != "":
lang_filter_params.append({"code": "dialect", "value": source_languages[0].value["dialect"]})
if target_languages is not None:
if isinstance(target_languages, Language):
target_languages = [target_languages]
if function == Function.TRANSLATION:
code = "targetlanguage"
lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]})
if sort_by is not None:
filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}]
if len(lang_filter_params) != 0:
filter_params["ioFilter"] = lang_filter_params
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}

logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}")
r = _request_with_retry("post", url, headers=headers, json=filter_params)
resp = r.json()

except Exception as e:
error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}"
logging.error(error_message, exc_info=True)
return []
if 200 <= r.status_code < 300:
logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}")
all_models = resp["items"]
from aixplain.factories.model_factory.utils import create_model_from_response

model_list = [create_model_from_response(model_info_json) for model_info_json in all_models]
return model_list, resp["total"]
else:
error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}"
logging.error(error_message)
raise Exception(error_message)
Loading