diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index fceac37a..39daa115 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -183,8 +183,9 @@ def get_messages(self, prefix='', destination=None) -> Dict[str, List[dict]]: if self.aggregator: messages = self.aggregator.aggregate(self.memory, self.name, self.output_format, self.template) if isinstance(messages, tuple): - messages, _ = messages + messages, tools = messages destination[prefix + 'messages'] = messages + destination[prefix + 'tools'] = tools for name, agent in getattr(self, '_agents', {}).items(): if isinstance(agent, Agent): agent.get_messages(destination=destination, prefix=prefix + name + '.') diff --git a/lagent/agents/internclaw_agent.py b/lagent/agents/internclaw_agent.py index 123b2343..0a9c4e20 100644 --- a/lagent/agents/internclaw_agent.py +++ b/lagent/agents/internclaw_agent.py @@ -128,15 +128,24 @@ async def forward(self, message, **kwargs): ) async def _inner_func(tool_call): tool_call = deepcopy(tool_call) + tool_name = tool_call['function'].get('name') try: - if tool_call['function']['name'].split('.', 1)[0] not in self.actions: + if tool_name.split('.', 1)[0] not in self.actions: return ActionReturn( - valid=ActionValidCode.INVALID, errmsg=f"Tool {tool_call['function']['name']} Not Found" + type=tool_name, + args=tool_call['function'].get('arguments'), + valid=ActionValidCode.INVALID, + errmsg=f"Tool {tool_name} Not Found", ) if isinstance(tool_call['function']['arguments'], str): tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) except Exception as e: - return ActionReturn(valid=ActionValidCode.INVALID, errmsg=str(e)) + return ActionReturn( + type=tool_name, + args=tool_call['function'].get('arguments'), + valid=ActionValidCode.INVALID, + errmsg=str(e), + ) tool_response: ActionReturn = ( await self.actions( AgentMessage( @@ -151,30 +160,29 @@ async def _inner_func(tool_call): tasks = [_inner_func(tool_call) for tool_call in message.tool_calls] responses = await asyncio.gather(*tasks) - for i, resp in enumerate(responses): - if resp.valid != ActionValidCode.OPEN: - return AgentMessage( - sender=self.name, - content=f'Tool Call Error: {resp.errmsg} in tool call ' - f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}', - ) - if resp.state != ActionStatusCode.SUCCESS: - return AgentMessage( - sender=self.name, - content=f'Tool Call Error: {resp.errmsg} in tool call ' - f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}', - reward=-1 if resp.state == ActionStatusCode.ARGS_ERROR else 0, - ) # Pair each ActionReturn with its tool_call_id for proper LLM API formatting tool_results = [] - for tc, r in zip(message.tool_calls, responses): - result_dict = asdict(r) + reward = 0.0 + for tc, resp in zip(message.tool_calls, responses): + result_dict = asdict(resp) result_dict['tool_call_id'] = tc.get('id', '') + if resp.valid != ActionValidCode.OPEN: + result_dict['errmsg'] = ( + f'Tool Call Error: {resp.errmsg} in tool call ' + f'{json.dumps(tc, ensure_ascii=False)}' + ) + elif resp.state != ActionStatusCode.SUCCESS: + result_dict['errmsg'] = ( + f'Tool Call Error: {resp.errmsg} in tool call ' + f'{json.dumps(tc, ensure_ascii=False)}' + ) + if resp.state == ActionStatusCode.ARGS_ERROR: + reward = -1 tool_results.append(result_dict) return_message = AgentMessage( sender=self.name, content=tool_results, - reward=0.0, + reward=reward, env_info=await self.get_env_info(), ) diff --git a/lagent/utils/util.py b/lagent/utils/util.py index 9b913b24..c57a7de1 100644 --- a/lagent/utils/util.py +++ b/lagent/utils/util.py @@ -9,7 +9,7 @@ import time from functools import partial from logging.handlers import RotatingFileHandler -from typing import Any, Dict, Generator, Iterable, List, Optional, Union, cast +from typing import Any, Dict, Generator, Iterable, List, Optional, Union def load_class_from_string(class_path: str, path=None): @@ -29,29 +29,34 @@ def load_class_from_string(class_path: str, path=None): sys.path.remove(path) +def _is_ray_actor_class(obj_type) -> bool: + try: + from ray.actor import ActorClass + except ImportError: + return False + return isinstance(obj_type, ActorClass) + + def create_object(config: Union[Dict, Any] = None): """Create an instance based on the configuration where 'type' is a preserved key to indicate the class (path). When accepting non-dictionary input, the function degenerates to an identity. """ - from ray.actor import ActorClass - if config is None or not isinstance(config, dict): return config - assert isinstance(config, dict) and 'type' in config + assert 'type' in config config = config.copy() obj_type = config.pop('type') if isinstance(obj_type, str): obj_type = load_class_from_string(obj_type) - if isinstance(obj_type, ActorClass): - obj = cast(ActorClass, obj_type).remote(**config) - elif inspect.isclass(obj_type): - obj = obj_type(**config) - else: - assert callable(obj_type) - obj = partial(obj_type, **config) - return obj + + if _is_ray_actor_class(obj_type): + return obj_type.remote(**config) + if inspect.isclass(obj_type): + return obj_type(**config) + assert callable(obj_type) + return partial(obj_type, **config) async def async_as_completed(futures: Iterable[asyncio.Future]):