Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def run(
data: Union[Text, Dict],
name: Text = "model_process",
timeout: float = 300,
parameters: Dict = {},
parameters: Optional[Dict] = {},
wait_time: float = 0.5,
) -> Dict:
"""Runs a model call.
Expand Down Expand Up @@ -220,7 +220,7 @@ def run(
response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
return response

def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Dict = {}) -> Dict:
def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> Dict:
"""Runs asynchronously a model call.

Args:
Expand Down
4 changes: 2 additions & 2 deletions aixplain/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run(
top_p: float = 1.0,
name: Text = "model_process",
timeout: float = 300,
parameters: Dict = {},
parameters: Optional[Dict] = {},
wait_time: float = 0.5,
) -> Dict:
"""Synchronously running a Large Language Model (LLM) model.
Expand Down Expand Up @@ -160,7 +160,7 @@ def run_async(
max_tokens: int = 128,
top_p: float = 1.0,
name: Text = "model_process",
parameters: Dict = {},
parameters: Optional[Dict] = {},
) -> Dict:
"""Runs asynchronously a model call.

Expand Down
2 changes: 1 addition & 1 deletion aixplain/modules/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict:
else:
response = resp
else:
resp = resp["error"] if "error" in resp else resp
resp = resp["error"] if isinstance(resp, dict) and "error" in resp else resp
if r.status_code == 401:
error = f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}"
elif 460 <= r.status_code < 470:
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/model/run_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def pytest_generate_tests(metafunc):
four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4)
models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"]

predefined_models = ["Groq Llama 3 70B", "Chat GPT 3.5", "GPT-4o", "GPT 4 (32k)"]
predefined_models = ["Groq Llama 3 70B", "Chat GPT 3.5", "GPT-4o"]
recent_models = [model for model in models if model.created_at and model.created_at >= four_weeks_ago]
combined_models = recent_models + [
ModelFactory.list(query=model, function=Function.TEXT_GENERATION)["results"][0] for model in predefined_models
Expand Down