diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index d745885c..5a8d1503 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -30,7 +30,12 @@ def create_model_from_response(response: Dict) -> Model: if "language" in param["name"]: parameters[param["name"]] = [w["value"] for w in param["values"]] - function = Function(response["function"]["id"]) + function_id = response["function"]["id"] + function = Function(function_id) + function_io = FunctionInputOutput.get(function_id, None) + input_params = {param["code"]: param for param in function_io["spec"]["params"]} + output_params = {param["code"]: param for param in function_io["spec"]["output"]} + inputs, temperature = [], None ModelClass = Model if function == Function.TEXT_GENERATION: @@ -44,15 +49,11 @@ def create_model_from_response(response: Dict) -> Model: UtilityModelInput(name=param["name"], description=param.get("description", ""), type=DataType(param["dataType"])) for param in response["params"] ] + input_params = {param["name"]: param for param in response["params"]} created_at = None if "createdAt" in response and response["createdAt"]: created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) - function_id = response["function"]["id"] - function = Function(function_id) - function_io = FunctionInputOutput.get(function_id, None) - input_params = {param["code"]: param for param in function_io["spec"]["params"]} - output_params = {param["code"]: param for param in function_io["spec"]["output"]} return ModelClass( response["id"], diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index a9fe514e..f3f597ef 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -35,10 +35,7 @@ class UtilityModelInput: description: Text type: DataType = DataType.TEXT - def __post_init__(self): - self.validate_type() - - def validate_type(self): + def validate(self): if self.type not in [DataType.TEXT, DataType.BOOLEAN, DataType.NUMBER]: raise ValueError("Utility Model Input type must be TEXT, BOOLEAN or NUMBER") @@ -124,6 +121,8 @@ def validate(self): self.description = description if len(self.inputs) == 0: self.inputs = inputs + for input in self.inputs: + input.validate() assert self.name and self.name.strip() != "", "Name is required" assert self.description and self.description.strip() != "", "Description is required" assert self.code and self.code.strip() != "", "Code is required"