Skip to content

Commit

Permalink
feat: supports array as input of plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
DynamesC committed Apr 12, 2024
1 parent c0af7f7 commit e3044b9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 12 deletions.
3 changes: 2 additions & 1 deletion backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ POSTGRES_URL=postgres://postgres:password@localhost:5432/taskingai
POSTGRES_MAX_CONNECTIONS=10
REDIS_URL=redis://localhost:6379/0

# inference
# other TaskingAI services
TASKINGAI_INFERENCE_URL=http://localhost:8002
TASKINGAI_PLUGIN_URL=http://localhost:8003

# secret
AES_ENCRYPTION_KEY=7700b2f9c8dd982dfaddf8b47a92f1d900507ee8ac335f96a64e9ca0f018b195
Expand Down
21 changes: 16 additions & 5 deletions backend/app/models/inference/chat_completion_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,31 @@ class ChatCompletionFunctionCall(BaseModel):
examples=["plus_a_and_b"],
)

class ChatCompletionFunctionParametersPropertyItems(BaseModel):
type: str = Field(
...,
pattern="^(string|number|integer|boolean)$",
description="The type of the item.",
)

class ChatCompletionFunctionParametersProperty(BaseModel):
# type in ["string", "number", "integer", "boolean"]
type: str = Field(
...,
pattern="^(string|number|integer|boolean)$",
pattern="^(string|number|integer|boolean|array|object)$",
description="The type of the parameter.",
)

description: str = Field(
"",
max_length=256,
description="The description of the parameter.",
# items only used in array
items: Optional[ChatCompletionFunctionParametersPropertyItems] = Field(
None,
description="The items of the parameter. Which is only allowed when type is 'array'.",
)

# description should not more than MAXIMUM_PARAMETER_DESCRIPTION_LENGTH characters
description: str = Field("", max_length=512, description="The description of the parameter.")

# optional enum
enum: Optional[List[str]] = Field(
None,
description="The enum list of the parameter. Which is only allowed when type is 'string'.",
Expand Down
13 changes: 11 additions & 2 deletions backend/app/models/tool/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ class ParameterType(str, Enum):
INTEGER = "integer" # int
NUMBER = "number" # float
BOOLEAN = "boolean" # bool
OBJECT = "object" # dict, only for output
ARRAY = "array" # list, only for output
# OBJECT = "object" # dict, only for output
# ARRAY = "array" # list, only for output
STRING_ARRAY = "string_array" # list of str
INTEGER_ARRAY = "integer_array" # list of int
NUMBER_ARRAY = "number_array" # list of float
BOOLEAN_ARRAY = "boolean_array" # list of bool
IMAGE_URL = "image_url" # str
FILE_URL = "file_url" # str

Expand Down Expand Up @@ -43,6 +47,11 @@ def transform_input_schema(bundle_id, plugin_id, plugin_description, input_schem
filtered_value = {k: v for k, v in value.items() if k in ["type", "enum", "description"]}
output_schema["properties"][key] = filtered_value

if "array" in filtered_value.get("type", ""):
array_type = filtered_value["type"].replace("_array", "")
output_schema["properties"][key]["type"] = "array"
output_schema["properties"][key]["items"] = {"type": array_type}

if filtered_value.get("description"):
# handle i18n
output_schema["properties"][key]["description"] = i18n_text(bundle_id, value["description"], "en")
Expand Down
9 changes: 5 additions & 4 deletions backend/app/services/tool/plugin/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ async def sync_plugin_data():
check_http_error(response_wrapper)
response_data = response_wrapper.json()["data"]

# sync i18n
global _i18n_dict
_i18n_dict = response_data["i18n"]

# sort plugins by bundle_id, name
plugins = [Plugin.build(plugin_data) for plugin_data in response_data["plugins"]]
plugins.sort(key=lambda x: (x.bundle_id, x.plugin_id))
Expand All @@ -64,8 +68,6 @@ async def sync_plugin_data():
bundles = [Bundle.build(bundle_data) for bundle_data in response_data["bundles"]]
bundles.sort(key=lambda x: x.bundle_id)

i18n_dict = response_data["i18n"]

for bundle in bundles:
bundle.num_plugins = num_plugins_dict.get(bundle.bundle_id, 0)

Expand All @@ -82,13 +84,12 @@ async def sync_plugin_data():
bundle_plugin_dict[bundle_id].sort(key=lambda x: x.plugin_id)

# update data
global _bundles, _plugins, _bundle_dict, _plugin_dict, _i18n_dict, _bundle_plugin_dict
global _bundles, _plugins, _bundle_dict, _plugin_dict, _bundle_plugin_dict
_bundles = bundles
_plugins = plugins
_bundle_dict = bundle_dict
_plugin_dict = plugin_dict
_bundle_plugin_dict = bundle_plugin_dict
_i18n_dict = i18n_dict

# update checksum
_bundle_checksum = bundle_checksum
Expand Down

0 comments on commit e3044b9

Please sign in to comment.