Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supports array as input of plugin #101

Merged
merged 1 commit into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ POSTGRES_URL=postgres://postgres:password@localhost:5432/taskingai
POSTGRES_MAX_CONNECTIONS=10
REDIS_URL=redis://localhost:6379/0

# inference
# other TaskingAI services
TASKINGAI_INFERENCE_URL=http://localhost:8002
TASKINGAI_PLUGIN_URL=http://localhost:8003

# secret
AES_ENCRYPTION_KEY=7700b2f9c8dd982dfaddf8b47a92f1d900507ee8ac335f96a64e9ca0f018b195
Expand Down
21 changes: 16 additions & 5 deletions backend/app/models/inference/chat_completion_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,31 @@ class ChatCompletionFunctionCall(BaseModel):
examples=["plus_a_and_b"],
)

class ChatCompletionFunctionParametersPropertyItems(BaseModel):
type: str = Field(
...,
pattern="^(string|number|integer|boolean)$",
description="The type of the item.",
)

class ChatCompletionFunctionParametersProperty(BaseModel):
# type in ["string", "number", "integer", "boolean"]
type: str = Field(
...,
pattern="^(string|number|integer|boolean)$",
pattern="^(string|number|integer|boolean|array|object)$",
description="The type of the parameter.",
)

description: str = Field(
"",
max_length=256,
description="The description of the parameter.",
# items only used in array
items: Optional[ChatCompletionFunctionParametersPropertyItems] = Field(
None,
description="The items of the parameter. Which is only allowed when type is 'array'.",
)

# description should not more than MAXIMUM_PARAMETER_DESCRIPTION_LENGTH characters
description: str = Field("", max_length=512, description="The description of the parameter.")

# optional enum
enum: Optional[List[str]] = Field(
None,
description="The enum list of the parameter. Which is only allowed when type is 'string'.",
Expand Down
13 changes: 11 additions & 2 deletions backend/app/models/tool/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ class ParameterType(str, Enum):
INTEGER = "integer" # int
NUMBER = "number" # float
BOOLEAN = "boolean" # bool
OBJECT = "object" # dict, only for output
ARRAY = "array" # list, only for output
# OBJECT = "object" # dict, only for output
# ARRAY = "array" # list, only for output
STRING_ARRAY = "string_array" # list of str
INTEGER_ARRAY = "integer_array" # list of int
NUMBER_ARRAY = "number_array" # list of float
BOOLEAN_ARRAY = "boolean_array" # list of bool
IMAGE_URL = "image_url" # str
FILE_URL = "file_url" # str

Expand Down Expand Up @@ -43,6 +47,11 @@ def transform_input_schema(bundle_id, plugin_id, plugin_description, input_schem
filtered_value = {k: v for k, v in value.items() if k in ["type", "enum", "description"]}
output_schema["properties"][key] = filtered_value

if "array" in filtered_value.get("type", ""):
array_type = filtered_value["type"].replace("_array", "")
output_schema["properties"][key]["type"] = "array"
output_schema["properties"][key]["items"] = {"type": array_type}

if filtered_value.get("description"):
# handle i18n
output_schema["properties"][key]["description"] = i18n_text(bundle_id, value["description"], "en")
Expand Down
9 changes: 5 additions & 4 deletions backend/app/services/tool/plugin/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ async def sync_plugin_data():
check_http_error(response_wrapper)
response_data = response_wrapper.json()["data"]

# sync i18n
global _i18n_dict
_i18n_dict = response_data["i18n"]

# sort plugins by bundle_id, name
plugins = [Plugin.build(plugin_data) for plugin_data in response_data["plugins"]]
plugins.sort(key=lambda x: (x.bundle_id, x.plugin_id))
Expand All @@ -64,8 +68,6 @@ async def sync_plugin_data():
bundles = [Bundle.build(bundle_data) for bundle_data in response_data["bundles"]]
bundles.sort(key=lambda x: x.bundle_id)

i18n_dict = response_data["i18n"]

for bundle in bundles:
bundle.num_plugins = num_plugins_dict.get(bundle.bundle_id, 0)

Expand All @@ -82,13 +84,12 @@ async def sync_plugin_data():
bundle_plugin_dict[bundle_id].sort(key=lambda x: x.plugin_id)

# update data
global _bundles, _plugins, _bundle_dict, _plugin_dict, _i18n_dict, _bundle_plugin_dict
global _bundles, _plugins, _bundle_dict, _plugin_dict, _bundle_plugin_dict
_bundles = bundles
_plugins = plugins
_bundle_dict = bundle_dict
_plugin_dict = plugin_dict
_bundle_plugin_dict = bundle_plugin_dict
_i18n_dict = i18n_dict

# update checksum
_bundle_checksum = bundle_checksum
Expand Down