Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions aixplain/factories/model_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def create_model_from_response(response: Dict) -> Model:
if "language" in param["name"]:
parameters[param["name"]] = [w["value"] for w in param["values"]]

function = Function(response["function"]["id"])
function_id = response["function"]["id"]
function = Function(function_id)
function_io = FunctionInputOutput.get(function_id, None)
input_params = {param["code"]: param for param in function_io["spec"]["params"]}
output_params = {param["code"]: param for param in function_io["spec"]["output"]}

inputs, temperature = [], None
ModelClass = Model
if function == Function.TEXT_GENERATION:
Expand All @@ -44,15 +49,11 @@ def create_model_from_response(response: Dict) -> Model:
UtilityModelInput(name=param["name"], description=param.get("description", ""), type=DataType(param["dataType"]))
for param in response["params"]
]
input_params = {param["name"]: param for param in response["params"]}

created_at = None
if "createdAt" in response and response["createdAt"]:
created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00"))
function_id = response["function"]["id"]
function = Function(function_id)
function_io = FunctionInputOutput.get(function_id, None)
input_params = {param["code"]: param for param in function_io["spec"]["params"]}
output_params = {param["code"]: param for param in function_io["spec"]["output"]}

return ModelClass(
response["id"],
Expand Down
7 changes: 3 additions & 4 deletions aixplain/modules/model/utility_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ class UtilityModelInput:
description: Text
type: DataType = DataType.TEXT

def __post_init__(self):
self.validate_type()

def validate_type(self):
def validate(self):
if self.type not in [DataType.TEXT, DataType.BOOLEAN, DataType.NUMBER]:
raise ValueError("Utility Model Input type must be TEXT, BOOLEAN or NUMBER")

Expand Down Expand Up @@ -124,6 +121,8 @@ def validate(self):
self.description = description
if len(self.inputs) == 0:
self.inputs = inputs
for input in self.inputs:
input.validate()
assert self.name and self.name.strip() != "", "Name is required"
assert self.description and self.description.strip() != "", "Description is required"
assert self.code and self.code.strip() != "", "Code is required"
Expand Down