diff --git a/lagent/agents/env_agent.py b/lagent/agents/env_agent.py index 4198468..f74a328 100644 --- a/lagent/agents/env_agent.py +++ b/lagent/agents/env_agent.py @@ -29,9 +29,10 @@ def __init__( enable_repeated_tool_call_penalty: bool = False, action_hooks: Optional[List[Union[dict, Hook]]] = None, name: Optional[str] = None, + **kwargs ): super().__init__( - actions, stateful_tools, max_tool_response_length, tool_response_truncate_side, action_hooks, name + actions, stateful_tools, max_tool_response_length, tool_response_truncate_side, action_hooks, name, **kwargs ) self.judger: AsyncAgent = create_object(judger) # scoring rule settings diff --git a/lagent/agents/fc_agent.py b/lagent/agents/fc_agent.py index 386ac78..7ff38cf 100644 --- a/lagent/agents/fc_agent.py +++ b/lagent/agents/fc_agent.py @@ -188,8 +188,9 @@ def __init__( tool_response_truncate_side: Literal['left', 'right', 'middle'] = 'middle', action_hooks: List[Union[dict, Hook]] = None, name: Optional[str] = None, + **kwargs ): - super().__init__(name=name) + super().__init__(name=name, **kwargs) if isinstance(actions, AsyncActionExecutor): for action_hook in action_hooks or []: actions.register_hook(create_object(action_hook))