Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lagent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,9 @@ def get_messages(self, prefix='', destination=None) -> Dict[str, List[dict]]:
if self.aggregator:
messages = self.aggregator.aggregate(self.memory, self.name, self.output_format, self.template)
if isinstance(messages, tuple):
messages, _ = messages
messages, tools = messages
destination[prefix + 'messages'] = messages
destination[prefix + 'tools'] = tools
for name, agent in getattr(self, '_agents', {}).items():
if isinstance(agent, Agent):
agent.get_messages(destination=destination, prefix=prefix + name + '.')
Expand Down
48 changes: 28 additions & 20 deletions lagent/agents/internclaw_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,24 @@ async def forward(self, message, **kwargs):
)
async def _inner_func(tool_call):
tool_call = deepcopy(tool_call)
tool_name = tool_call['function'].get('name')
try:
if tool_call['function']['name'].split('.', 1)[0] not in self.actions:
if tool_name.split('.', 1)[0] not in self.actions:
return ActionReturn(
valid=ActionValidCode.INVALID, errmsg=f"Tool {tool_call['function']['name']} Not Found"
type=tool_name,
args=tool_call['function'].get('arguments'),
valid=ActionValidCode.INVALID,
errmsg=f"Tool {tool_name} Not Found",
)
if isinstance(tool_call['function']['arguments'], str):
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
except Exception as e:
return ActionReturn(valid=ActionValidCode.INVALID, errmsg=str(e))
return ActionReturn(
type=tool_name,
args=tool_call['function'].get('arguments'),
valid=ActionValidCode.INVALID,
errmsg=str(e),
)
tool_response: ActionReturn = (
await self.actions(
AgentMessage(
Expand All @@ -151,30 +160,29 @@ async def _inner_func(tool_call):

tasks = [_inner_func(tool_call) for tool_call in message.tool_calls]
responses = await asyncio.gather(*tasks)
for i, resp in enumerate(responses):
if resp.valid != ActionValidCode.OPEN:
return AgentMessage(
sender=self.name,
content=f'Tool Call Error: {resp.errmsg} in tool call '
f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}',
)
if resp.state != ActionStatusCode.SUCCESS:
return AgentMessage(
sender=self.name,
content=f'Tool Call Error: {resp.errmsg} in tool call '
f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}',
reward=-1 if resp.state == ActionStatusCode.ARGS_ERROR else 0,
)
# Pair each ActionReturn with its tool_call_id for proper LLM API formatting
tool_results = []
for tc, r in zip(message.tool_calls, responses):
result_dict = asdict(r)
reward = 0.0
for tc, resp in zip(message.tool_calls, responses):
result_dict = asdict(resp)
result_dict['tool_call_id'] = tc.get('id', '')
if resp.valid != ActionValidCode.OPEN:
result_dict['errmsg'] = (
f'Tool Call Error: {resp.errmsg} in tool call '
f'{json.dumps(tc, ensure_ascii=False)}'
)
elif resp.state != ActionStatusCode.SUCCESS:
result_dict['errmsg'] = (
f'Tool Call Error: {resp.errmsg} in tool call '
f'{json.dumps(tc, ensure_ascii=False)}'
)
if resp.state == ActionStatusCode.ARGS_ERROR:
reward = -1
tool_results.append(result_dict)
return_message = AgentMessage(
sender=self.name,
content=tool_results,
reward=0.0,
reward=reward,
env_info=await self.get_env_info(),
)

Expand Down
29 changes: 17 additions & 12 deletions lagent/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
from functools import partial
from logging.handlers import RotatingFileHandler
from typing import Any, Dict, Generator, Iterable, List, Optional, Union, cast
from typing import Any, Dict, Generator, Iterable, List, Optional, Union


def load_class_from_string(class_path: str, path=None):
Expand All @@ -29,29 +29,34 @@ def load_class_from_string(class_path: str, path=None):
sys.path.remove(path)


def _is_ray_actor_class(obj_type) -> bool:
try:
from ray.actor import ActorClass
except ImportError:
return False
return isinstance(obj_type, ActorClass)


def create_object(config: Union[Dict, Any] = None):
"""Create an instance based on the configuration where 'type' is a
preserved key to indicate the class (path). When accepting non-dictionary
input, the function degenerates to an identity.
"""
from ray.actor import ActorClass

if config is None or not isinstance(config, dict):
return config
assert isinstance(config, dict) and 'type' in config
assert 'type' in config

config = config.copy()
obj_type = config.pop('type')
if isinstance(obj_type, str):
obj_type = load_class_from_string(obj_type)
if isinstance(obj_type, ActorClass):
obj = cast(ActorClass, obj_type).remote(**config)
elif inspect.isclass(obj_type):
obj = obj_type(**config)
else:
assert callable(obj_type)
obj = partial(obj_type, **config)
return obj

if _is_ray_actor_class(obj_type):
return obj_type.remote(**config)
if inspect.isclass(obj_type):
return obj_type(**config)
assert callable(obj_type)
return partial(obj_type, **config)


async def async_as_completed(futures: Iterable[asyncio.Future]):
Expand Down