Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions available_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class AvailableTools:
"""
This class is used for storing dictionaries of all the available
personalities, taskflows, and prompts.
"""
def __init__(self, personalities: dict, taskflows: dict, prompts: dict):
self.personalities = personalities
self.taskflows = taskflows
self.prompts = prompts
37 changes: 19 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from yaml_parser import YamlParser
from agent import TaskAgent
from capi import list_tool_call_models
from available_tools import AvailableTools

load_dotenv()

Expand All @@ -48,9 +49,8 @@
MAX_API_RETRY = 5
MCP_CLEANUP_TIMEOUT = 5

def parse_prompt_args(user_prompt: str | None = None):
available_personalities = YamlParser('personalities').get_yaml_dict()
available_taskflows = YamlParser('taskflows').get_yaml_dict()
def parse_prompt_args(available_tools: AvailableTools,
user_prompt: str | None = None):
parser = argparse.ArgumentParser(add_help=False, description="SecLab Taskflow Agent")
parser.prog = ''
group = parser.add_mutually_exclusive_group()
Expand All @@ -61,10 +61,10 @@ def parse_prompt_args(user_prompt: str | None = None):
#parser.add_argument('remainder', nargs=argparse.REMAINDER, help="Remaining args")
help_msg = parser.format_help()
help_msg += "\nAvailable Personalities:\n\n"
for k in available_personalities:
for k in available_tools.personalities:
help_msg += f"`{k}`\n"
help_msg += "\nAvailable Taskflows:\n\n"
for k in available_taskflows:
for k in available_tools.taskflows:
help_msg += f"`{k}`\n"
help_msg += "\nExamples:\n\n"
help_msg += "`-p assistant explain modems to me please`\n"
Expand Down Expand Up @@ -372,11 +372,8 @@ async def _run_streamed():
logging.error(f"Exception in mcp server cleanup task: {e}")


async def main(p: str | None, t: str | None, prompt: str | None):

available_personalities = YamlParser('personalities').get_yaml_dict()
available_taskflows = YamlParser('taskflows').get_yaml_dict()
available_prompts = YamlParser('prompts').get_yaml_dict(dir_namespace=True)
async def main(available_tools: AvailableTools,
p: str | None, t: str | None, prompt: str | None):
last_mcp_tool_results = [] # XXX: memleaky

async def on_tool_end_hook(
Expand All @@ -399,7 +396,7 @@ async def on_handoff_hook(
await render_model_output(f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n")

if p:
personality = available_personalities.get(p)
personality = available_tools.personalities.get(p)
if personality is None:
raise ValueError("No such personality!")

Expand All @@ -412,7 +409,7 @@ async def on_handoff_hook(

if t:

taskflow = available_taskflows.get(t)
taskflow = available_tools.taskflows.get(t)
if taskflow is None:
raise ValueError("No such taskflow!")

Expand All @@ -431,7 +428,7 @@ async def on_handoff_hook(
# can tweak reusable task configurations as they see fit
uses = task_body.get('uses', '')
if uses:
reusable_taskflow = available_taskflows.get(uses)
reusable_taskflow = available_tools.taskflows.get(uses)
if reusable_taskflow is None:
raise ValueError(f"No such reusable taskflow: {uses}")
if len(reusable_taskflow['taskflow']) > 1:
Expand Down Expand Up @@ -479,7 +476,7 @@ def preprocess_prompt(prompt: str, tag: str, kv: dict, kv_subkey=None):

# pre-process the prompt for any prompts
if prompt:
prompt = preprocess_prompt(prompt, 'PROMPTS', available_prompts, 'prompt')
prompt = preprocess_prompt(prompt, 'PROMPTS', available_tools.prompts, 'prompt')

# pre-process the prompt for any inputs
if prompt and inputs:
Expand Down Expand Up @@ -566,10 +563,10 @@ async def run_prompts(async_task=False, max_concurrent_tasks=5):
if not agents:
# XXX: deprecate the -p parser for taskflows entirely?
# XXX: probably just adds unneeded parsing complexity
p, _, _, prompt, _ = parse_prompt_args(prompt)
p, _, _, prompt, _ = parse_prompt_args(available_tools, prompt)
agents.append(p)
for p in agents:
personality = available_personalities.get(p)
personality = available_tools.personalities.get(p)
if personality is None:
raise ValueError(f"No such personality: {p}")
resolved_agents[p] = personality
Expand Down Expand Up @@ -628,8 +625,12 @@ async def _deploy_task_agents(resolved_agents, prompt):
break

if __name__ == '__main__':
available_tools = AvailableTools(
personalities = YamlParser('personalities').get_yaml_dict(),
taskflows = YamlParser('taskflows').get_yaml_dict(),
prompts = YamlParser('prompts').get_yaml_dict(dir_namespace=True))

p, t, l, user_prompt, help_msg = parse_prompt_args()
p, t, l, user_prompt, help_msg = parse_prompt_args(available_tools)

if l:
tool_models = list_tool_call_models(os.getenv('COPILOT_TOKEN'))
Expand All @@ -641,4 +642,4 @@ async def _deploy_task_agents(resolved_agents, prompt):
print(help_msg)
sys.exit(1)

asyncio.run(main(p, t, user_prompt), debug=True)
asyncio.run(main(available_tools, p, t, user_prompt), debug=True)