40 changes: 40 additions & 0 deletions available_tools.py
@@ -26,6 +26,7 @@ def __init__(self, yamls: dict):
self.prompts = {}
self.toolboxes = {}
self.model_config = {}
self.namespace_config = {}

# Iterate through all the yaml files and divide them into categories.
# Each file should contain a header like this:
@@ -52,6 +53,8 @@ def __init__(self, yamls: dict):
add_yaml_to_dict(self.toolboxes, filekey, yaml)
elif filetype == 'model_config':
add_yaml_to_dict(self.model_config, filekey, yaml)
elif filetype == 'namespace_config':
add_yaml_to_dict(self.namespace_config, filekey, yaml)
else:
raise FileTypeException(str(filetype))
except KeyError as err:
@@ -62,3 +65,40 @@ def __init__(self, yamls: dict):
logging.error(f'{path}: file ID {err.args[0]} is not unique')
except FileTypeException as err:
logging.error(f'{path}: seclab-taskflow-agent file type {err.args[0]} is not supported')

def copy_with_alias(self, alias_dict : dict) -> dict:
def _copy_add_alias_to_dict(original_dict : dict, alias_dict : dict) -> dict:
new_dict = dict(original_dict)
alias_keys = alias_dict.keys()
for k,v in original_dict.items():
for ak in alias_keys:
if k.startswith(ak) and len(k) > len(ak) and k[len(ak)] == '/':
new_key = alias_dict[ak] + k[len(ak):]
new_dict[new_key] = v
Comment on lines +72 to +77 (Collaborator):

The loop over alias_keys looks inefficient. I think something like this would be better:

Suggested change
-    alias_keys = alias_dict.keys()
-    for k,v in original_dict.items():
-        for ak in alias_keys:
-            if k.startswith(ak) and len(k) > len(ak) and k[len(ak)] == '/':
-                new_key = alias_dict[ak] + k[len(ak):]
-                new_dict[new_key] = v
+    for k,v in original_dict.items():
+        ak = k.split('/')[0]
+        av = alias_dict.get(ak)
+        if av:
+            new_key = av + k[len(ak):]
+            new_dict[new_key] = v

return new_dict
new_available_tools = AvailableTools({})
new_available_tools.personalities = _copy_add_alias_to_dict(self.personalities, alias_dict)
new_available_tools.taskflows = _copy_add_alias_to_dict(self.taskflows, alias_dict)
new_available_tools.prompts = _copy_add_alias_to_dict(self.prompts, alias_dict)
# toolboxes are looked up after they are canonicalized
new_available_tools.toolboxes = dict(self.toolboxes)
new_available_tools.model_config = _copy_add_alias_to_dict(self.model_config, alias_dict)
new_available_tools.namespace_config = _copy_add_alias_to_dict(self.namespace_config, alias_dict)
return new_available_tools

def canonicalize_toolboxes(toolboxes : list, alias_dict : dict) -> list:
"""
Toolboxes need to be canonicalized because both personalities and taskflows can use toolboxes with potentially different aliases
"""
out = set()
if not alias_dict:
return toolboxes
for tb in toolboxes:
found_alias = False
for k,v in alias_dict.items():
Comment (Collaborator):

Same comment as above about looping over all the aliases.

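For illustration only (not part of this PR or of the reviewer's suggestion), a direct-lookup version of this loop in canonicalize_toolboxes might look roughly like the sketch below; it assumes an alias value is always a single leading path segment that contains no '/'.

    # hypothetical sketch: invert the alias map once, then resolve each
    # toolbox prefix with a single dictionary lookup instead of a nested loop
    reverse_aliases = {v: k for k, v in alias_dict.items()}
    for tb in toolboxes:
        prefix = tb.split('/')[0]
        alias_key = reverse_aliases.get(prefix)
        if alias_key:
            out.add(alias_key + tb[len(prefix):])
        else:
            out.add(tb)
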
if tb.startswith(v) and len(tb) > len(v) and tb[len(v)] == '/':
out.add(k + tb[len(v):])
found_alias = True
if not found_alias:
out.add(tb)
return list(out)
6 changes: 6 additions & 0 deletions configs/namespace_config.yaml
@@ -0,0 +1,6 @@
seclab-taskflow-agent:
version: 1
filetype: namespace_config
filekey: GitHubSecurityLab/seclab-taskflow-agent/configs/namespace_config
namespace_aliases:
GitHubSecurityLab/seclab-taskflow-agent : seclab-ta
33 changes: 23 additions & 10 deletions main.py
@@ -30,7 +30,7 @@
from yaml_parser import YamlParser
from agent import TaskAgent
from capi import list_tool_call_models
from available_tools import AvailableTools
from available_tools import AvailableTools, canonicalize_toolboxes

load_dotenv()

@@ -85,6 +85,15 @@ def parse_prompt_args(available_tools: AvailableTools,
l = args[0].l
return p, t, l, ' '.join(args[0].prompt), help_msg

def _get_namespace_aliases(available_tools, yaml_dict : dict) -> dict:
namespace_config = yaml_dict.get('namespace_config', '')
namespace_aliases = yaml_dict.get('namespace_aliases', {})
if namespace_config:
namespace_config = available_tools.namespace_config.get(namespace_config, {})
namespace_aliases = namespace_aliases | namespace_config.get('namespace_aliases', {})
return namespace_aliases


async def deploy_task_agents(available_tools: AvailableTools,
agents: dict,
prompt: str,
@@ -114,7 +123,10 @@ async def deploy_task_agents(available_tools: AvailableTools,
# otherwise all agents have the disjunction of all their tools available
for k, v in agents.items():
if v.get('toolboxes', []):
toolboxes += [tb for tb in v['toolboxes'] if tb not in toolboxes]
this_toolboxes = [tb for tb in v['toolboxes']]
namespace_aliases = _get_namespace_aliases(available_tools, v)
this_toolboxes = canonicalize_toolboxes(this_toolboxes, namespace_aliases)
toolboxes += [tb for tb in this_toolboxes if tb not in toolboxes]

# https://openai.github.io/openai-agents-python/ref/model_settings/
parallel_tool_calls = True if os.getenv('MODEL_PARALLEL_TOOL_CALLS') else False
@@ -229,7 +241,6 @@ async def mcp_session_task(
logging.error(f"RuntimeError in mcp session task: {e}")
except asyncio.CancelledError as e:
logging.error(f"Timeout on main session task: {e}")
pass
finally:
mcp_servers.clear()

@@ -440,7 +451,9 @@ async def on_handoff_hook(
if not isinstance(model_dict, dict):
raise ValueError(f"Models section of the model_config file {model_config} must be a dictionary")
model_keys = model_dict.keys()

namespace_aliases = _get_namespace_aliases(available_tools, taskflow)
this_available_tools = available_tools.copy_with_alias(namespace_aliases)

for task in taskflow['taskflow']:

task_body = task['task']
@@ -451,7 +464,7 @@ async def on_handoff_hook(
# can tweak reusable task configurations as they see fit
uses = task_body.get('uses', '')
if uses:
reusable_taskflow = available_tools.taskflows.get(uses)
reusable_taskflow = this_available_tools.taskflows.get(uses)
if reusable_taskflow is None:
raise ValueError(f"No such reusable taskflow: {uses}")
if len(reusable_taskflow['taskflow']) > 1:
@@ -475,7 +488,7 @@ async def on_handoff_hook(
raise ValueError('shell task and prompt task are mutually exclusive!')
must_complete = task_body.get('must_complete', False)
max_turns = task_body.get('max_steps', DEFAULT_MAX_TURNS)
toolboxes_override = task_body.get('toolboxes', [])
toolboxes_override = canonicalize_toolboxes(task_body.get('toolboxes', []), namespace_aliases)
env = task_body.get('env', {})
repeat_prompt = task_body.get('repeat_prompt', False)
# this will set Agent 'stop_on_first_tool' tool use behavior, which prevents output back to llm
@@ -500,7 +513,7 @@ def preprocess_prompt(prompt: str, tag: str, kv: dict, kv_subkey=None):

# pre-process the prompt for any prompts
if prompt:
prompt = preprocess_prompt(prompt, 'PROMPTS', available_tools.prompts, 'prompt')
prompt = preprocess_prompt(prompt, 'PROMPTS', this_available_tools.prompts, 'prompt')

# pre-process the prompt for any inputs
if prompt and inputs:
@@ -587,10 +600,10 @@ async def run_prompts(async_task=False, max_concurrent_tasks=5):
if not agents:
# XXX: deprecate the -p parser for taskflows entirely?
# XXX: probably just adds unneeded parsing complexity
p, _, _, prompt, _ = parse_prompt_args(available_tools, prompt)
p, _, _, prompt, _ = parse_prompt_args(this_available_tools, prompt)
agents.append(p)
for p in agents:
personality = available_tools.personalities.get(p)
personality = this_available_tools.personalities.get(p)
if personality is None:
raise ValueError(f"No such personality: {p}")
resolved_agents[p] = personality
@@ -599,7 +612,7 @@ async def run_prompts(async_task=False, max_concurrent_tasks=5):
async def _deploy_task_agents(resolved_agents, prompt):
async with semaphore:
result = await deploy_task_agents(
available_tools,
this_available_tools,
# pass agents and prompt by assignment, they change in-loop
resolved_agents,
prompt,
4 changes: 3 additions & 1 deletion personalities/examples/echo.yaml
@@ -3,12 +3,14 @@ seclab-taskflow-agent:
filetype: personality
filekey: GitHubSecurityLab/seclab-taskflow-agent/personalities/examples/echo

namespace_aliases:
GitHubSecurityLab/seclab-taskflow-agent : seclab-ta
personality: |
You are a simple echo bot. You use echo tools to echo things.

task: |
Echo user inputs using the echo tools.

toolboxes:
- GitHubSecurityLab/seclab-taskflow-agent/toolboxes/echo
- seclab-ta/toolboxes/echo

5 changes: 3 additions & 2 deletions taskflows/examples/echo.yaml
@@ -2,14 +2,15 @@ seclab-taskflow-agent:
version: 1
filetype: taskflow
filekey: GitHubSecurityLab/seclab-taskflow-agent/taskflows/examples/echo

namespace_aliases:
GitHubSecurityLab/seclab-taskflow-agent : seclab-ta
taskflow:
- task:
model: claude-3.5-sonnet
max_steps: 5
must_complete: true
agents:
- GitHubSecurityLab/seclab-taskflow-agent/personalities/examples/echo
- seclab-ta/personalities/examples/echo
user_prompt: |
Hello
- task:
4 changes: 3 additions & 1 deletion taskflows/examples/example.yaml
@@ -3,6 +3,8 @@ seclab-taskflow-agent:
filetype: taskflow
filekey: GitHubSecurityLab/seclab-taskflow-agent/taskflows/examples/example

namespace_config: GitHubSecurityLab/seclab-taskflow-agent/configs/namespace_config

taskflow:
- task:
# taskflows can optionally choose any of the supported CAPI models for a task
@@ -35,7 +37,7 @@ taskflow:
# this normally only has the memcache toolbox, but we extend it here with
# the GHSA toolbox
toolboxes:
- GitHubSecurityLab/seclab-taskflow-agent/toolboxes/memcache
- seclab-ta/toolboxes/memcache
- GitHubSecurityLab/seclab-taskflow-agent/toolboxes/codeql
- task:
must_complete: true