In [85]:
import asyncio
import sys
import time
import logging
import inspect
import typing as ty
from functools import wraps

import nest_asyncio
nest_asyncio.apply()
import aioitertools
from langchain.input import get_color_mapping
from langchain.agents import AgentExecutor, Tool
from langchain.agents.agent import AgentAction, AgentFinish
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun, Callbacks, CallbackManager, AsyncCallbackManager
)
from langchain.utilities.asyncio import asyncio_timeout
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
sys.path.insert(0, "..")
from codeine.chatbot import build_chat_engine, service_context

level = logging.DEBUG
#level = logging.INFO
logging.basicConfig(level=level)
logger = logging.getLogger(__name__)
logger.setLevel(level)

chat_engine = build_chat_engine()

In [112]:
def rebuild_callback_manager_on_set(
    setter_method: ty.Callable[..., None]
) -> ty.Callable[..., None]:
    """Decorator to force setters to rebuild callback mgr"""
    @wraps(setter_method)
    def wrapper(self: ty.Any, *args: ty.Any, **kwargs: ty.Any) -> None:
        setter_method(self, *args, **kwargs)
        self.build_callback_manager()
    return wrapper

class AgentExecutorIterator:
    def __init__(
        self,
        agent_executor: AgentExecutor,
        inputs: dict[str, str] | str,
        callbacks: Callbacks = None,
        *,
        tags: list[str] | None = None,
        include_run_info: bool = False,
        async_: bool = False
    ):
        """
        Initialize the AgentExecutorIterator with the given AgentExecutor, 
        inputs, and optional callbacks.
        """
        self._agent_executor = agent_executor
        self.inputs = inputs
        self.async_ = async_
        # build callback manager on tags setter
        self._callbacks = callbacks
        self.tags = tags 
        self.include_run_info = include_run_info
        self.run_manager = None
    
    @property
    def inputs(self) -> dict[str, str]:
        return self._inputs
    
    @inputs.setter
    def inputs(self, inputs: dict[str, str] | str) -> None:
        self._inputs = self.agent_executor.prep_inputs(inputs)
    
    @property
    def callbacks(self):
        return self._callbacks
    
    @property
    def tags(self):
        return self._tags
    
    @property
    def agent_executor(self):
        return self._agent_executor
    
    @callbacks.setter
    @rebuild_callback_manager_on_set
    def callbacks(self, callbacks: Callbacks) -> None:
        """When callbacks are changed after __init__, rebuild callback mgr"""
        self._callbacks = callbacks
    
    @tags.setter
    @rebuild_callback_manager_on_set
    def tags(self, tags: list[str] | None) -> None:
        """When tags are changed after __init__, rebuild callback mgr"""
        self._tags = tags
    
    @agent_executor.setter
    @rebuild_callback_manager_on_set
    def agent_executor(self, agent_executor: AgentExecutor) -> None:
        self._agent_executor = agent_executor
        # force re-prep inputs incase agent_executor's prep_inputs fn changed
        self.inputs = self.inputs 
        
    @property
    def callback_manager(self) -> AsyncCallbackManager | CallbackManager:
        return self._callback_manager
    
    def build_callback_manager(self) -> None:
        """
        Create and configure the callback manager based on the current callbacks and tags.
        """
        CallbackMgr = AsyncCallbackManager if self.async_ else CallbackManager
        self._callback_manager = CallbackMgr.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags
        )        

    @property
    def name_to_tool_map(self):
        return {tool.name: tool for tool in self.agent_executor.tools}
    
    @property
    def color_mapping(self):
        return get_color_mapping(
            [tool.name for tool in self.agent_executor.tools],
            excluded_colors=["green", "red"]
        )
    
    def reset(self):
        """
        Reset the iterator to its initial state, clearing intermediate steps, iterations, and time elapsed.
        """
        logger.debug(f"(Re)setting AgentExecutorIterator to fresh state")
        self.intermediate_steps: list[tuple[AgentAction, str]] = []
        self.iterations = 0
        # maybe better to start these on the first __anext__ call?
        self.time_elapsed = 0.0
        self.start_time = time.time()
        self._final_outputs = None
        
    def update_iterations(self):
        """
        Increment the number of iterations and update the time elapsed.
        """
        self.iterations += 1
        self.time_elapsed = time.time() - self.start_time
        logger.debug(f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)")

    def raise_stopiteration(self, output: ty.Any):
        """
        Raise a StopIteration exception with the given output.
        """
        logger.debug("Chain end: stop iteration")
        raise StopIteration(output)
    
    async def raise_stopasynciteration(self, output: ty.Any):
        """
        Raise a StopAsyncIteration exception with the given output.
        Close the timeout context manager.
        """
        logger.debug("Chain end: stop async iteration")
        if self.timeout_manager is not None:
            await self.timeout_manager.__aexit__(None, None, None)
        raise StopAsyncIteration(output)
    
    @property
    def final_outputs(self):
        return self._final_outputs
    
    @final_outputs.setter
    def final_outputs(self, outputs):
        # have access to intermediate steps by design in iterator,
        # so return only outputs may as well always be true.
        final_outputs: dict[str, ty.Any] = self.agent_executor.prep_outputs(
            self.inputs, outputs, return_only_outputs=True
        )
        if self.include_run_info and self.run_manager is not None:
            logger.debug("Assign run key")
            final_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id)
        self._final_outputs = final_outputs
    
    def __iter__(self):
        logger.debug("Initialising AgentExecutorIterator")
        self.reset()
        self.run_manager = self.callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
        )
        return self
    
    def __aiter__(self):
        """
        N.B. __aiter__ must be a normal method, so need to initialise async run manager 
        on first __anext__ call where we can await it
        """
        logger.debug("Initialising AgentExecutorIterator (async)")
        self.reset()
        if self.agent_executor.max_execution_time:
            self.timeout_manager = asyncio_timeout(self.agent_executor.max_execution_time)
        else:
            self.timeout_manager = None
        return self

    def _on_first_step(self) -> None:
        """
        Perform any necessary setup for the first step of the synchronous iterator.
        """
        pass
            
    async def _on_first_async_step(self) -> None:
        """
        Perform any necessary setup for the first step of the asynchronous iterator.
        """
        # on first step, need to await callback manager and start async timeout ctxmgr
        if not self.iterations:
            self.run_manager = await self.callback_manager.on_chain_start(
                dumpd(self.agent_executor),
                self.inputs,
            )
            if self.timeout_manager:
                await self.timeout_manager.__aenter__()
    
    def __next__(self) -> dict[str, ty.Any]:
        """
        AgentExecutor               AgentExecutorIterator
        __call__                    (__iter__ ->) __next__
            _call              <=>      _call_next
                _take_next_step             _take_next_step   
        """
        # first step
        if not self.iterations:
            self._on_first_step()
        # N.B. timeout taken care of by "_should_continue" in sync case
        try:
            return self._call_next()
        except (KeyboardInterrupt, Exception) as e:
            self.run_manager.on_chain_error(e)
            raise
            
    async def __anext__(self) -> dict[str, ty.Any]:
        """
        AgentExecutor               AgentExecutorIterator
        acall                       (__aiter__ ->) __anext__
            _acall              <=>     _acall_next
                _atake_next_step            _atake_next_step   
        """
        if not self.iterations:
            await self._on_first_async_step()
        try:
            return await self._acall_next()
        except TimeoutError:
            await self._astop()
        except (KeyboardInterrupt, Exception) as e:
            await self.run_manager.on_chain_error(e)
            raise
        
    def _execute_next_step(self):
        """
        Execute the next step in the chain using the AgentExecutor's _take_next_step method.
        """
        return self.agent_executor._take_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=self.run_manager,
        )

    async def _execute_next_async_step(self):
        """
        Execute the next step in the chain using the AgentExecutor's _atake_next_step method.
        """
        return await self.agent_executor._atake_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=self.run_manager,
        )

    def _process_next_step_output(self, next_step_output, run_manager):
        """
        Process the output of the next step, handling AgentFinish and tool return cases.
        """
        logger.debug("Processing output of Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(f"Hit AgentFinish: _return -> on_chain_end -> run final output logic")
            output = self.agent_executor._return(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if self.run_manager:
                self.run_manager.on_chain_end(output)
            self.final_outputs = output
            return self.final_outputs

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = self.agent_executor._return(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if self.run_manager:
                    self.run_manager.on_chain_end(output)
                self.final_outputs = output
                return self.final_outputs

        output = {"intermediate_steps": self.intermediate_steps}
        return output

    async def _aprocess_next_step_output(self, next_step_output, run_manager):
        """
        Process the output of the next async step, handling AgentFinish and tool return cases.
        """
        logger.debug("Processing output of async Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(f"Hit AgentFinish: _areturn -> on_chain_end -> run final output logic")
            output = await self.agent_executor._areturn(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if self.run_manager:
                await self.run_manager.on_chain_end(output)
            self.final_outputs = output
            return self.final_outputs

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = await self.agent_executor._areturn(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if self.run_manager:
                    await self.run_manager.on_chain_end(output)
                self.final_outputs = output
                return self.final_outputs

        output = {"intermediate_steps": self.intermediate_steps}
        return output
    
    def _stop(self) -> None:
        """
        Stop the iterator and raise a StopIteration exception with the stopped response.
        """
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs
        )
        output = self.agent_executor._return(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        self.raise_stopiteration(output)
    
    async def _astop(self) -> None:
        """
        Stop the async iterator and raise a StopAsyncIteration exception with 
        the stopped response.
        """
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs
        )
        output = await self.agent_executor._areturn(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        await self.raise_stopasynciteration(output)
        
    def _call_next(self) -> dict[str, ty.Any]:
        """
        Perform a single iteration of the synchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            self.raise_stopiteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            self._stop()            
        next_step_output = self._execute_next_step()
        output = self._process_next_step_output(next_step_output, self.run_manager)
        self.update_iterations()
        return output

    async def _acall_next(self) -> dict[str, ty.Any]:
        """
        Perform a single iteration of the asynchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            await self.raise_stopasynciteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            await self._astop()       
        next_step_output = await self._execute_next_async_step()
        output = await self._aprocess_next_step_output(next_step_output, self.run_manager)
        self.update_iterations()
        return output

In [115]:
class MyAgentExecutor(AgentExecutor):
    def __call__(
        self,
        inputs: dict[str, str] | ty.Any,
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        *,
        tags: list[str] | None = None,
        include_run_info: bool = False,
        iterator: bool = False,
        async_: bool = False,
    ) -> dict[str, ty.Any]:
        if iterator:
            return AgentExecutorIterator(
                self,
                inputs,
                callbacks,
                tags=tags,
                include_run_info=include_run_info,
                async_=async_
            )    
        else:
            return super().__call__(
                inputs,
                return_only_outputs,
                callbacks,
                tags=tags,
                include_run_info=include_run_info
            )

In [105]:
agent_executor.callback_manager

In [116]:
agent_executor = MyAgentExecutor.from_agent_and_tools(
    agent=chat_engine._agent.agent,
    tools=chat_engine._agent.tools,
    callback_manager=chat_engine._agent.callback_manager
)
agent_executor.memory = chat_engine._agent.memory
inspect.signature(agent_executor.__call__)

<Signature (inputs: Union[dict[str, str], Any], return_only_outputs: bool = False, callbacks: Union[List[langchain.callbacks.base.BaseCallbackHandler], langchain.callbacks.base.BaseCallbackManager, NoneType] = None, *, tags: list[str] | None = None, include_run_info: bool = False, iterator: bool = False, async_: bool = False) -> dict[str, typing.Any]>

In [117]:
inputs = "Tell me about the structure of the codeine source code"

for step in agent_executor(inputs=inputs, iterator=True):
    print("*** STEP:")
    print(step)
    print("***")

DEBUG:__main__:Initialising AgentExecutorIterator
DEBUG:__main__:(Re)setting AgentExecutorIterator to fresh state
INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 5 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 2855 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens
DEBUG:__main__:Processing output of Agent loop step
DEBUG:__main__:Updated intermediate_steps with step output
DEBUG:__main__:Agent Iterations: 1 (6.27s elapsed)


*** STEP:
{'intermediate_steps': [(AgentAction(tool='Codeine Source Code Search', tool_input='Codeine source code structure', log='```json\n{\n    "action": "Codeine Source Code Search",\n    "action_input": "Codeine source code structure"\n}\n```'), 'The Codeine source code is structured into multiple files, including presets.py, chatbot.py, utils.py, README.md, and LICENSE. The presets.py file contains a theme for the Gradio frontend, while chatbot.py contains code for building a chat engine. The utils.py file contains various utility functions, including one for converting Markdown to HTML with syntax highlighting. The README.md file provides installation instructions and information about the project, while the LICENSE file outlines the permissions and conditions for using the Codeine source code.')]}
***


DEBUG:__main__:Processing output of Agent loop step
DEBUG:__main__:Hit AgentFinish: _return -> on_chain_end -> run final output logic
DEBUG:__main__:Agent Iterations: 2 (10.42s elapsed)
DEBUG:__main__:Chain end: stop iteration


*** STEP:
{'output': 'The Codeine source code is structured into multiple files, including presets.py, chatbot.py, utils.py, README.md, and LICENSE. The presets.py file contains a theme for the Gradio frontend, while chatbot.py contains code for building a chat engine. The utils.py file contains various utility functions, including one for converting Markdown to HTML with syntax highlighting. The README.md file provides installation instructions and information about the project, while the LICENSE file outlines the permissions and conditions for using the Codeine source code.'}
***


In [118]:
inputs = "Tell me about the structure of the codeine source code"
async_mae_iter = agent_executor(inputs=inputs, iterator=True, async_=True)
async_mae_iter.inputs = "Tell me about ze structure of the codeine source code"
async_mae_iter.agent_executor = async_mae_iter.agent_executor
async for step in async_mae_iter:
    print("*** STEP:")
    print(step)
    print("***")

DEBUG:__main__:Initialising AgentExecutorIterator (async)
DEBUG:__main__:(Re)setting AgentExecutorIterator to fresh state
INFO:openai:message='OpenAI API response' path=https://api.openai.com/v1/chat/completions processing_ms=1346 request_id=c1a286a1fd70fcd75dfb2507dd600199 response_code=200
INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 5 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 2855 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens
DEBUG:__main__:Processing output of async Agent loop step
DEBUG:__main__:Updated intermediate_steps with step output
DEBUG:__main__:Agent Iterations: 1 (6.97s elapsed)


*** STEP:
{'intermediate_steps': [(AgentAction(tool='Codeine Source Code Search', tool_input='Codeine source code structure', log='```json\n{\n    "action": "Codeine Source Code Search",\n    "action_input": "Codeine source code structure"\n}\n```'), 'The Codeine source code is structured into multiple files, including presets.py, chatbot.py, utils.py, README.md, and LICENSE. The presets.py file contains a theme for the Gradio frontend, while chatbot.py contains code for building a chat engine. The utils.py file contains various utility functions, including one for converting Markdown to HTML with syntax highlighting. The README.md file provides installation instructions and information about the project, while the LICENSE file outlines the permissions and conditions for using the Codeine source code.')]}
***


INFO:openai:message='OpenAI API response' path=https://api.openai.com/v1/chat/completions processing_ms=3573 request_id=6f33d6c5f85b077c81bde7735001ced3 response_code=200
DEBUG:__main__:Processing output of async Agent loop step
DEBUG:__main__:Hit AgentFinish: _areturn -> on_chain_end -> run final output logic
DEBUG:__main__:Agent Iterations: 2 (14.62s elapsed)
DEBUG:__main__:Chain end: stop async iteration


*** STEP:
{'output': 'The Codeine source code is structured into multiple files, including presets.py, chatbot.py, utils.py, README.md, and LICENSE. The presets.py file contains a theme for the Gradio frontend, while chatbot.py contains code for building a chat engine. The utils.py file contains various utility functions, including one for converting Markdown to HTML with syntax highlighting. The README.md file provides installation instructions and information about the project, while the LICENSE file outlines the permissions and conditions for using the Codeine source code.'}
***


Two pieces of code on which design/refactor is based:

Original AgentExecutor (non-iterator version) for which we want to mimic the logic
https://github.com/hwchase17/langchain/blob/2da1aab50b43c63c7a9a9553b7290230c44604bc/langchain/agents/agent.py#L620

The inherited `__call__` and `acall` methods from Chain:
https://github.com/hwchase17/langchain/blob/22af93d8516a4ecc05e2c814ad5660c0b6427625/langchain/chains/base.py#L126