Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions lagent/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
StreamingSequential,
)
from .compact_agent import AsyncCompactAgent, estimate_token_count
from .internclaw_agent import (
AsyncEnvAgent,
AsyncPolicyAgent,
InternClawAgent,
)
from .internclaw_agent import AsyncEnvAgent, InternClawAgent

__all__ = [
'Agent',
Expand All @@ -31,6 +27,5 @@
'AsyncCompactAgent',
'estimate_token_count',
'AsyncEnvAgent',
'AsyncPolicyAgent',
'InternClawAgent',
]
26 changes: 16 additions & 10 deletions lagent/agents/aggregator/default_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@

class DefaultAggregator:

def aggregate(self,
messages: Memory,
name: str,
parser: StrParser = None,
system_instruction: str = None,
tools: List[Dict] = None,
) -> Tuple[List[Dict[str, str]], Optional[List[Dict]]]:
def aggregate(
self,
messages: Memory,
name: str,
parser: StrParser = None,
system_instruction: str = None,
tools: List[Dict] = None,
) -> Tuple[List[Dict[str, str]], Optional[List[Dict]]]:
_message = []
messages = messages.get_memory()
if system_instruction:
Expand All @@ -38,12 +39,17 @@ def aggregate(self,
)
)
else:
if len(_message) > 0 and _message[-1]['role'] == 'user':
if (
len(_message) > 0
and _message[-1]['role'] == 'user'
and isinstance(_message[-1]['content'], str)
and isinstance(user_message, str)
):
_message[-1]['content'] += user_message
_message[-1]['extra_info'] = extra_info
else:
_message.append(dict(role='user', content=user_message, extra_info=extra_info))

latest_env_info = None
for message in messages:
if getattr(message, 'env_info', None) is not None:
Expand All @@ -52,7 +58,7 @@ def aggregate(self,
tools_to_use = tools
if latest_env_info and latest_env_info.get("tools"):
tools_to_use = latest_env_info.get("tools")

return _message, tools_to_use

@staticmethod
Expand Down
164 changes: 132 additions & 32 deletions lagent/agents/fc_agent.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import asyncio
import json
import logging
import platform
from copy import deepcopy
from dataclasses import asdict
from typing import Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Protocol, Union

from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed

from lagent.actions import AsyncActionExecutor
from lagent.hooks import Hook
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode, AgentMessage, AgentStatusCode
from lagent.skills.skills import SkillsLoader
from lagent.utils import create_object, truncate_text
from .agent import AsyncAgent

logger = logging.getLogger("lagent.agents.fc_agent")

DEFAULT_TOOL_TEMPLATE = """# Tools

You may call one or more functions to assist with the user query.
Expand All @@ -27,7 +32,9 @@
</tool_call>"""


def get_tool_prompt(actions: list, exclude_arguments: list = None, template: str = DEFAULT_TOOL_TEMPLATE) -> str:
def get_tool_prompt(
actions: list, exclude_arguments: list = None, to_string: bool = True, template: str = DEFAULT_TOOL_TEMPLATE
) -> Union[str, List[dict]]:
exclude_arguments = exclude_arguments or ['session_id']

def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') -> dict:
Expand Down Expand Up @@ -57,60 +64,138 @@ def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') ->
tools.append(_convert_tool_schema(api, f"{action.name}.{{}}"))
else:
tools.append(_convert_tool_schema(action_desc))
if not to_string:
return tools
return template.format(tools='\n'.join([json.dumps(tool, ensure_ascii=False) for tool in tools]))


class FunctionCallAgent(AsyncAgent):
def __init__(
self,
select_agent: Union[Dict, AsyncAgent],
policy_agent: Union[Dict, AsyncAgent],
env_agent: Union[Dict, AsyncAgent],
finish_condition: callable = lambda x, _: x and not x.tool_calls,
compact_agent: Optional[Dict] = None,
consolidate_agent: Optional[Dict] = None,
finish_condition: callable = lambda m, _: m and not m.tool_calls,
max_turn: Optional[int] = None,
initialize_input: bool = True,
name: Optional[str] = None,
):
super().__init__(name=name)
self.select_agent = create_object(select_agent)
self.policy_agent = create_object(policy_agent)
self.env_agent = create_object(env_agent)
self.compact_agent = create_object(compact_agent)
self.consolidate_agent = create_object(consolidate_agent)
self.finish_condition = finish_condition
self.max_turn = max_turn
self.initialize_input = initialize_input

async def forward(self, env_message: AgentMessage, session_id: str | int, **kwargs):
selection_message: AgentMessage = None
async def forward(self, env_message: AgentMessage, **kwargs):
policy_message: AgentMessage = None
current_turn = 0
while (self.finish_condition is None or not self.finish_condition(selection_message, env_message)) and (
if self.initialize_input:
env_message = await self.env_agent(env_message, **kwargs)

while (self.finish_condition is None or not self.finish_condition(policy_message, env_message)) and (
self.max_turn is None or current_turn < self.max_turn
):
selection_message = await self.select_agent(env_message, session_id=session_id, **kwargs)
if selection_message.stream_state == AgentStatusCode.SERVER_ERR:
policy_message = await self.policy_agent(env_message, **kwargs)
if policy_message.stream_state == AgentStatusCode.SERVER_ERR:
raise ValueError("Rollout response error: state is neither completed nor aborted!")
if selection_message.stream_state == AgentStatusCode.SESSION_OUT_OF_LIMIT:
if policy_message.stream_state == AgentStatusCode.SESSION_OUT_OF_LIMIT:
for _ in range(2): # remove the last two messages
self.select_agent.memory.get(session_id).delete(-1)
self.policy_agent.memory.delete(-1)
return AgentMessage(
sender=self.name,
content='Exceeded context length limit',
finish_reason=selection_message.finish_reason,
content=policy_message.content,
finish_reason=policy_message.finish_reason,
)
if selection_message.finish_reason == 'abort':
return AgentMessage(sender=self.name, content='Aborted request', finish_reason='abort')
env_message = await self.env_agent(selection_message, session_id=session_id)
if policy_message.finish_reason == 'abort':
return AgentMessage(sender=self.name, content=policy_message.content, finish_reason='abort')

# Orchestrator manages memory
await self._maybe_manage_memory(policy_message, env_message)

env_message = await self.env_agent(policy_message)
current_turn += 1
if policy_message is not None:
return AgentMessage(sender=self.name, content=policy_message.content, finish_reason='stop')
return AgentMessage(sender=self.name, content="Finished", finish_reason='stop')

async def _maybe_manage_memory(self, policy_message: AgentMessage, env_message: AgentMessage) -> None:
"""Orchestrate compact and consolidate.

Orchestrator calls policy's aggregator to get formatted_messages,
checks should_compact, and if triggered:
1. Runs consolidate_agent (optional)
2. Runs compact_agent to produce summary
3. Injects summary + boundary into env_message
ContextBuilder reads these on the next turn.
"""
if not self.compact_agent:
return

from lagent.agents.compact_agent import estimate_token_count

state = self.get_messages()
formatted_messages, tools = state['policy_agent.messages'], state['policy_agent.tools']
compact_input = AgentMessage(
sender=self.name,
content=formatted_messages,
extra_info={'context_tokens': estimate_token_count(formatted_messages, tools)},
)
if not (hasattr(self.compact_agent, 'should_compact') and self.compact_agent.should_compact(compact_input)):
return

# 1. Consolidate first (preserve info before compacting)
if self.consolidate_agent:
try:
await self.consolidate_agent(compact_input)
self.consolidate_agent.reset(recursive=True)
logger.info("Consolidation completed")
except Exception:
logger.exception("Consolidation failed, continuing with compact")
# 2. Compact — inject summary + boundary into env_message
try:
summary_msg = await self.compact_agent(compact_input)
self.compact_agent.reset(recursive=True)
if summary_msg and summary_msg.content:
if env_message.env_info is None:
env_message.env_info = {}
env_message.env_info['conversation_summary'] = summary_msg.content
env_message.env_info['compact_boundary'] = len(self.policy_agent.memory.memory)
logger.info("Compact summary injected (%d chars)", len(summary_msg.content))
except Exception:
logger.exception("Compact failed")


class MemoryProvider(Protocol):
async def get_info(self) -> dict:
"""Return long-term memory info for EnvAgent's env_info. The content and format are flexible, but should be concise."""
...


class EnvAgent(AsyncAgent):
def __init__(
self,
actions: list,
actions,
skills: Optional[SkillsLoader] = None,
long_term_memory: Optional[MemoryProvider] = None,
stateful_tools: List[str] = None,
max_tool_response_length: int = None,
tool_response_truncate_side: Literal['left', 'right', 'middle'] = 'middle',
action_hooks: List[Union[dict, Hook]] = None,
name: Optional[str] = None,
):
super().__init__(name=name)
self.actions = AsyncActionExecutor(actions, hooks=action_hooks)
if isinstance(actions, AsyncActionExecutor):
for action_hook in action_hooks or []:
actions.register_hook(create_object(action_hook))
self.actions = actions
else:
self.actions = AsyncActionExecutor(actions, hooks=action_hooks)
self.skills = create_object(skills)
self.long_term_memory = create_object(long_term_memory)
self.stateful_tools = stateful_tools or []
self.max_tool_response_length = max_tool_response_length
self.tool_response_truncate_side = tool_response_truncate_side
Expand All @@ -124,41 +209,56 @@ def __init__(
retry_error_callback=lambda retry_state: retry_state.outcome.result(),
)

async def forward(self, selection_message: AgentMessage, session_id: str | int, **kwargs):
if not selection_message.tool_calls:
return AgentMessage(sender=self.name, content='No tool call')
async def get_env_info(self) -> Dict[str, Any]:
env_info = {'skills': '', 'active_skills': '', 'memory': '', 'tools': [], 'runtime': {}}
if self.skills is not None:
env_info['skills'] = await self.skills.build_skills_summary()
always_skills = await self.skills.get_always_skills()
if always_skills:
env_info['active_skills'] = await self.skills.load_skills_for_context(always_skills)
if self.long_term_memory is not None:
env_info['memory'] = await self.long_term_memory.get_info()
if self.actions:
env_info['tools'] = get_tool_prompt(list(self.actions.actions.values()), to_string=False)
for name in ['system', 'machine', 'python_version']:
env_info['runtime'][name] = getattr(platform, name)()
return env_info

async def forward(self, message: AgentMessage, **kwargs):
if not message.tool_calls:
return AgentMessage(sender=self.name, content=message.content, env_info=await self.get_env_info())

tool_responses = await asyncio.gather(
*[
self._retry_mechanism(self.execute_tool)(tool_call, session_id)
for tool_call in selection_message.tool_calls
]
*[self._retry_mechanism(self.execute_tool)(tool_call) for tool_call in message.tool_calls]
)
for tool_call_id, tool_response in zip(selection_message.tool_calls_ids, tool_responses):
for tool_call_id, tool_response in zip(message.tool_calls_ids, tool_responses):
tool_response.tool_call_id = tool_call_id
res = tool_response.format_result()
if self.max_tool_response_length is not None and len(res) > self.max_tool_response_length:
res = truncate_text(res, max_num=self.max_tool_response_length, side=self.tool_response_truncate_side)
tool_response.result = [{'type': 'text', 'content': res}]
return AgentMessage(sender=self.name, content=[asdict(resp) for resp in tool_responses])
return AgentMessage(
sender=self.name, content=[asdict(resp) for resp in tool_responses], env_info=await self.get_env_info()
)

async def execute_tool(self, tool_call: dict, session_id: str | int) -> ActionReturn:
async def execute_tool(self, tool_call: dict) -> ActionReturn:
tool_call = deepcopy(tool_call)
try:
if 'function' in tool_call:
tool_call = tool_call['function']
if tool_call['name'].split('.', 1)[0] not in self.actions:
return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Tool {tool_call["name"]} Not Found')
if isinstance(tool_call['arguments'], str):
tool_call['arguments'] = json.loads(tool_call['arguments'])
if tool_call['name'] in self.stateful_tools:
tool_call = deepcopy(tool_call)
tool_call['arguments']['session_id'] = session_id
tool_call['arguments']['session_id'] = str(id(self))
except Exception as e:
return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Invalid tool call format: {str(e)}')
tool_response: ActionReturn = (
await self.actions(
AgentMessage(
sender='assistant', content=dict(name=tool_call['name'], parameters=tool_call['arguments'])
),
session_id=session_id,
)
).content
return tool_response
27 changes: 4 additions & 23 deletions lagent/agents/internclaw_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,6 @@ def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') ->
return tools


class AsyncPolicyAgent(AsyncAgent):

async def forward(self, *message, **kwargs):
formatted_messages, tools = self.aggregator.aggregate(
self.memory, self.name, self.output_format, self.template
)
llm_response = await self.llm.chat(formatted_messages, tools=tools, **kwargs)
# message = AgentMessage(
# sender=self.name,
# content=llm_response.get('content') or '',
# tool_calls=llm_response.get('tool_calls') or [],
# reasoning_content=llm_response.get('reasoning_content'),
# )
# return message
return llm_response


class AsyncEnvAgent(AsyncAgent):
def __init__(self, actions, skills: SkillsLoader = None, long_term_memory=None, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -168,13 +151,11 @@ async def _inner_func(tool_call):
result_dict['tool_call_id'] = tc.get('id', '')
if resp.valid != ActionValidCode.OPEN:
result_dict['errmsg'] = (
f'Tool Call Error: {resp.errmsg} in tool call '
f'{json.dumps(tc, ensure_ascii=False)}'
f'Tool Call Error: {resp.errmsg} in tool call ' f'{json.dumps(tc, ensure_ascii=False)}'
)
elif resp.state != ActionStatusCode.SUCCESS:
result_dict['errmsg'] = (
f'Tool Call Error: {resp.errmsg} in tool call '
f'{json.dumps(tc, ensure_ascii=False)}'
f'Tool Call Error: {resp.errmsg} in tool call ' f'{json.dumps(tc, ensure_ascii=False)}'
)
if resp.state == ActionStatusCode.ARGS_ERROR:
reward = -1
Expand Down Expand Up @@ -352,7 +333,7 @@ async def main():

# ── 4. Policy agent ──
aggregator = InternClawContextBuilder(workspace, tools=None)
policy = AsyncPolicyAgent(
policy = AsyncAgent(
llm=model,
aggregator=aggregator,
hooks=[logger_hook],
Expand All @@ -374,7 +355,7 @@ async def main():
)

# ── 7. Consolidate agent (standard InternClawAgent) ──
consolidate_policy = AsyncPolicyAgent(
consolidate_policy = AsyncAgent(
name='consolidate_policy',
llm=model,
template=CONSOLIDATION_PROMPT,
Expand Down
Loading