Skip to content

Commit

Permalink
feat: add tool result size check
Browse files Browse the repository at this point in the history
  • Loading branch information
jameszyao authored and taskingaijc committed Apr 7, 2024
1 parent c5f8820 commit e171e12
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 27 deletions.
57 changes: 49 additions & 8 deletions backend/app/services/tool/action/openapi_call.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import os
import aiohttp
from aiohttp.client_exceptions import ClientConnectorError, ClientResponseError
import logging
import urllib.parse
import json

from typing import Dict, Optional
from app.models import (
ActionAuthentication,
Expand All @@ -7,10 +13,6 @@
ActionBodyType,
ActionParam,
)
import aiohttp
from aiohttp.client_exceptions import ClientConnectorError
import logging
import urllib.parse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -141,10 +143,34 @@ async def call_action_api(

async with session.request(method.value, url, **request_kwargs) as response:
response_content_type = response.headers.get("Content-Type", "").lower()

bytes_read = 0
max_size = 64 * 1024
data_chunks = []

# check the size of the response
async for chunk in response.content.iter_any():
bytes_read += len(chunk)
if bytes_read > max_size:
raise ClientResponseError(
response.request_info,
response.history,
message="Response too large",
status=response.status,
)
data_chunks.append(chunk)

data_bytes = b"".join(data_chunks)
if "application/json" in response_content_type:
data = await response.json()
try:
# Assuming the response is JSON and decode here
data = json.loads(data_bytes.decode("utf-8"))
except json.JSONDecodeError:
# Handle non-JSON response or decode error
return {"status": 500, "data": {"error": "Failed to decode the action response"}}
else:
data = {"result": await response.text()}
data = {"result": data_bytes.decode("utf-8")}

if response.status != 200:
error_message = f"API call failed with status {response.status}"
if data:
Expand All @@ -153,7 +179,22 @@ async def call_action_api(
return {"status": response.status, "data": data}

except ClientConnectorError as e:
return {"status": 500, "data": {"error": f"Failed to connect to {url}"}}
return {
"status": 500,
"data": {"error": f"Failed to connect to {url}"},
}

except ClientResponseError as e:
if e.message == "Response too large":
return {
"status": e.status,
"data": {"error": f"Response data is too large. Maximum character length is {max_size}."},
}
else:
return {"status": e.status, "data": {"error": f"API call failed with status {e.status}"}}

except Exception as e:
return {"status": 500, "error": {"error": "Failed to make the API call"}}
return {
"status": 500,
"error": {"error": "Failed to make the API call"},
}
78 changes: 59 additions & 19 deletions backend/app/services/tool/plugin/plugin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import json
from typing import Dict, List
from app.models import BundleInstance
from app.operators import bundle_instance_ops
import aiohttp
from app.config import CONFIG
from aiohttp.client_exceptions import ClientResponseError

from tkhelper.utils import ResponseWrapper

from app.models import BundleInstance
from app.config import CONFIG
from app.operators import bundle_instance_ops

__all__ = [
"run_plugin",
"get_bundle_registered_dict",
Expand All @@ -27,22 +31,58 @@ async def run_plugin(
bundle_instance_id=bundle_instance_id,
)

async with aiohttp.ClientSession() as session:
response = await session.post(
f"{CONFIG.TASKINGAI_PLUGIN_URL}/v1/execute",
json={
"bundle_id": bundle_instance.bundle_id,
"plugin_id": plugin_id,
"input_params": parameters,
"encrypted_credentials": bundle_instance.encrypted_credentials,
},
)
response_wrapper = ResponseWrapper(response.status, await response.json())
if response.status == 200:
data = response_wrapper.json().get("data")
return {"status": data["status"], "data": data["data"]}

return {"status": response.status, "data": response_wrapper.json().get("error")}
try:
async with aiohttp.ClientSession() as session:
response = await session.post(
f"{CONFIG.TASKINGAI_PLUGIN_URL}/v1/execute",
json={
"bundle_id": bundle_instance.bundle_id,
"plugin_id": plugin_id,
"input_params": parameters,
"encrypted_credentials": bundle_instance.encrypted_credentials,
},
)

bytes_read = 0
max_size = 64 * 1024
data_chunks = []

# check the size of the response
async for chunk in response.content.iter_any():
bytes_read += len(chunk)
if bytes_read > max_size:
raise ClientResponseError(
response.request_info, response.history, message="Response too large", status=response.status
)
data_chunks.append(chunk)

data_bytes = b"".join(data_chunks)
try:
# Assuming the response is JSON and decode here
data_dict = json.loads(data_bytes.decode("utf-8"))
except json.JSONDecodeError:
# Handle non-JSON response or decode error
return {"status": 500, "data": {"error": "Failed to decode the plugin response"}}

response_wrapper = ResponseWrapper(response.status, data_dict)

if response.status == 200:
data = response_wrapper.json().get("data")
return {"status": data["status"], "data": data["data"]}

return {"status": response.status, "data": response_wrapper.json().get("error")}

except ClientResponseError as e:
if "Response too large" in e.message:
return {
"status": e.status,
"data": {"error": f"Response data is too large. Maximum character length is {max_size}."},
}
else:
return {"status": e.status, "data": {"error": f"API call failed with status {e.status}"}}

except Exception as e:
return {"status": 500, "data": {"error": "Failed to execute the plugin"}}


async def get_bundle_registered_dict(
Expand Down

0 comments on commit e171e12

Please sign in to comment.