Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def run(
data: Union[Text, Dict],
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> Dict:
"""Runs a model call.
Expand All @@ -197,7 +197,7 @@ def run(
data (Union[Text, Dict]): link to the input data
name (Text, optional): ID given to a call. Defaults to "model_process".
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

Returns:
Expand All @@ -220,13 +220,13 @@ def run(
response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
return response

def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> Dict:
def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None) -> Dict:
"""Runs asynchronously a model call.

Args:
data (Union[Text, Dict]): link to the input data
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.

Returns:
dict: polling URL in response
Expand Down
36 changes: 20 additions & 16 deletions aixplain/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run(
top_p: float = 1.0,
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> Dict:
"""Synchronously running a Large Language Model (LLM) model.
Expand All @@ -117,21 +117,23 @@ def run(
top_p (float, optional): Top P. Defaults to 1.0.
name (Text, optional): ID given to a call. Defaults to "model_process".
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.

Returns:
Dict: parsed output from model
"""
start = time.time()
if parameters is None:
parameters = {}
parameters.update(
{
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
"context": parameters.get("context", context),
"prompt": parameters.get("prompt", prompt),
"history": parameters.get("history", history),
"temperature": parameters.get("temperature", temperature),
"max_tokens": parameters.get("max_tokens", max_tokens),
"top_p": parameters.get("top_p", top_p),
}
)
payload = build_payload(data=data, parameters=parameters)
Expand Down Expand Up @@ -160,7 +162,7 @@ def run_async(
max_tokens: int = 128,
top_p: float = 1.0,
name: Text = "model_process",
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
) -> Dict:
"""Runs asynchronously a model call.

Expand All @@ -173,21 +175,23 @@ def run_async(
max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
top_p (float, optional): Top P. Defaults to 1.0.
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.

Returns:
dict: polling URL in response
"""
url = f"{self.url}/{self.id}"
logging.debug(f"Model Run Async: Start service for {name} - {url}")
if parameters is None:
parameters = {}
parameters.update(
{
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
"context": parameters.get("context", context),
"prompt": parameters.get("prompt", prompt),
"history": parameters.get("history", history),
"temperature": parameters.get("temperature", temperature),
"max_tokens": parameters.get("max_tokens", max_tokens),
"top_p": parameters.get("top_p", top_p),
}
)
payload = build_payload(data=data, parameters=parameters)
Expand Down
7 changes: 5 additions & 2 deletions aixplain/modules/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import json
import logging
from aixplain.utils.file_utils import _request_with_retry
from typing import Dict, Text, Union
from typing import Dict, Text, Union, Optional


def build_payload(data: Union[Text, Dict], parameters: Dict = {}):
def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None):
from aixplain.factories import FileFactory

if parameters is None:
parameters = {}

data = FileFactory.to_link(data)
if isinstance(data, dict):
payload = data
Expand Down