Below is a step‑by‑step explanation of what’s going on when RunnableWithMessageHistory sets up this snippet:

history_chain: Runnable = RunnableLambda(
    self._enter_history, self._aenter_history
).with_config(run_name="load_history")

messages_key = history_messages_key or input_messages_key
if messages_key:
    history_chain = (
        RunnablePassthrough
        .assign(**{messages_key: history_chain})
        .with_config(run_name="insert_history")
    )
Overall Goal
RunnableWithMessageHistory has two major jobs:

1. Load the chat history from the user’s storage (such as a database) and place it into the incoming dictionary.
2. After the main chain finishes, append the new user message and the new AI message to the chat history storage.
The snippet you asked about handles (1), loading the chat history into the dictionary. Specifically, it:

Creates a history_chain (RunnableLambda) that fetches a list of messages from your stored conversation.
Then merges those messages into the pipeline’s dictionary under a key like "chat_history" (or whichever messages_key is given).
That way, the next node in the chain has the entire conversation so far.

Step by Step
history_chain = RunnableLambda(self._enter_history, self._aenter_history):

RunnableLambda is a Runnable that calls either the sync function (_enter_history) or the async function (_aenter_history) depending on whether you do .invoke() or .ainvoke().

Inside _enter_history/_aenter_history, RunnableWithMessageHistory does something like:

hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy()
...
# If no history_messages_key is set, the incoming input messages are appended to the list as well
return messages
In other words, it returns a list of all messages up to now from your conversation storage. This “list of messages” is the RunnableLambda output.

if messages_key: (like "chat_history" or "history") – we do:

history_chain = RunnablePassthrough.assign(**{messages_key: history_chain})
RunnablePassthrough.assign(...) means:

“Take the input dictionary, pass it through untouched, except add or overwrite a new key–value pair: messages_key = <the output of the sub-runnable>.”
So effectively, we have:

The pipeline input dictionary (say, {"input": "User question"}).
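We run history_chain on that same input to produce the list of stored messages, e.g. [SystemMessage("..."), HumanMessage("...")]. (assign runs its sub-runnables as a RunnableParallel; here there is only one entry, so nothing actually runs side by side.)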
Then we put that list as dict[messages_key] (like dict["chat_history"] = [...]).
We pass the final dictionary on to the next step in the chain.
Hence, the final dictionary going forward might look like:

{
    "input": "User question",
    "chat_history": [
       SystemMessage(...),
       HumanMessage(...),
       ...
    ]
}
Why a two-step structure?
history_chain alone returns list_of_messages. But the chain system expects to keep the original dictionary plus a new field.
RunnablePassthrough.assign(...) merges those two results together.
So the pipeline says:
“Get the conversation from history_chain.”
“Store it under a dictionary key like chat_history.”
“Pass that dictionary forward.”
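
To make the "keep the dictionary, add one key" behaviour concrete, here is a minimal plain-Python sketch. It does not use LangChain: the assign helper and load_history function are hypothetical stand-ins that only mimic the idea behind RunnablePassthrough.assign, not its real implementation.

```python
from typing import Any, Callable


def assign(**steps: Callable[[dict], Any]) -> Callable[[dict], dict]:
    """Plain-Python stand-in for the idea behind RunnablePassthrough.assign."""

    def run(inputs: dict) -> dict:
        # Every step receives the ORIGINAL input dict.
        computed = {key: step(inputs) for key, step in steps.items()}
        # Keep the original keys and add (or overwrite) the computed ones.
        return {**inputs, **computed}

    return run


def load_history(_inputs: dict) -> list[str]:
    # Hypothetical stand-in for history_chain: returns the stored messages.
    return ["system: be brief", "human: hi", "ai: hello"]


insert_history = assign(chat_history=load_history)
result = insert_history({"input": "User question"})
print(result)
```

The original "input" key survives untouched, and the list returned by load_history appears under "chat_history".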
Final Flow in RunnableWithMessageHistory
Putting it all together in RunnableWithMessageHistory:

Before the main chain runs, we do history_chain so that the dictionary has the full conversation.
The chain can see chat_history or history and incorporate it into the prompt.
After the chain completes, _exit_history merges the brand-new AI message (and user’s message if needed) back into permanent storage.
Essentially:

“Load” messages → chat_history.
Call user’s chain with that chat_history.
On end, “save” new messages.
That’s how the snippet works in-depth.
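
The load, call, save cycle can also be sketched without LangChain. Everything below (InMemoryHistory, run_with_history, echo_chain) is a hypothetical illustration of the flow, assuming a history object that exposes messages and add_messages; it is not how RunnableWithMessageHistory is actually implemented.

```python
from typing import Callable


class InMemoryHistory:
    """Hypothetical stand-in for a BaseChatMessageHistory implementation."""

    def __init__(self) -> None:
        self.messages: list[tuple[str, str]] = []

    def add_messages(self, new_messages: list[tuple[str, str]]) -> None:
        self.messages.extend(new_messages)


def run_with_history(
    chain: Callable[[dict], str], history: InMemoryHistory, user_input: str
) -> str:
    # 1. Load: copy the stored messages into the chain input.
    inputs = {"input": user_input, "chat_history": history.messages.copy()}
    # 2. Call: the wrapped chain sees the conversation so far.
    answer = chain(inputs)
    # 3. Save: persist the new human message and the new AI message.
    history.add_messages([("human", user_input), ("ai", answer)])
    return answer


def echo_chain(inputs: dict) -> str:
    # Toy chain that reports how many earlier messages it received.
    return f"seen {len(inputs['chat_history'])} earlier messages"


history = InMemoryHistory()
first = run_with_history(echo_chain, history, "hello")
second = run_with_history(echo_chain, history, "and again")
print(first)   # seen 0 earlier messages
print(second)  # seen 2 earlier messages
```

The second call sees two earlier messages because the first call saved one human and one AI message on its way out.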

You said:
history_chain = RunnablePassthrough.assign(**{messages_key: history_chain}) And how is it possible these runnables are not executed right away but rather only when the chain is invoked?
In LangChain’s Runnable architecture, creating or chaining runnables is a declarative (lazy) process, whereas actually running them only happens when you eventually call something like .invoke(...), .ainvoke(...), or .stream(...). In other words:

When you write code like:

history_chain = RunnablePassthrough.assign(**{messages_key: history_chain})
you’re building (or extending) a “pipeline” object that simply holds references to its sub-runnables. No work is done at this time; you’re only composing runnables.

At some later point, you call:

final_result = history_chain.invoke(input_dict)
or the chain you built is used as part of a larger chain that is .invoke(...)d. Only then does the system:

Take input_dict.
Pass it through the RunnablePassthrough.assign(...) logic.
Actually execute history_chain (the sub-runnable) to produce its output.
Assign that output to the specified dictionary key.
Return or feed that final dictionary to the next node.
Hence, the runnables aren’t executed right away during the chain construction — they’re merely configured. Execution only happens once the final chain is called with .invoke(...) (or .ainvoke(...), etc.). This is how LangChain’s “DAG” of runnables keeps everything lazy until you explicitly run it.
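
A tiny stand-in shows why composing is cheap and invoking is where the work happens. The Step class below is a hypothetical toy, not LangChain's Runnable; it only assumes that composition stores functions and that invoke() is the one place they get called.

```python
from typing import Any, Callable

calls: list[str] = []  # records which steps actually ran


class Step:
    """Toy stand-in for a Runnable: stores a function, runs it only on invoke()."""

    def __init__(self, func: Callable[[Any], Any]) -> None:
        self.func = func

    def __or__(self, other: "Step") -> "Step":
        # Composition builds a new Step; neither function is called here.
        return Step(lambda value: other.invoke(self.invoke(value)))

    def invoke(self, value: Any) -> Any:
        return self.func(value)


def load(value: dict) -> dict:
    calls.append("load")
    return {**value, "chat_history": []}


def answer(value: dict) -> str:
    calls.append("answer")
    return f"answer to {value['input']}"


pipeline = Step(load) | Step(answer)
calls_after_building = list(calls)  # still empty: nothing has run yet

output = pipeline.invoke({"input": "hi"})
calls_after_invoking = list(calls)  # both steps have now run, in order
```

Building pipeline only creates objects; the calls list stays empty until invoke() is reached.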

You said:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy()   How is this pulling messages from let's say a mongo database?
In that snippet:

hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy()
hist is already an instance of some concrete BaseChatMessageHistory subclass (e.g. MyMongoDBChatMessageHistory). That subclass’s messages property does the actual retrieval from Mongo (or whichever database). For example, it might look like:

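from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage

class MyMongoDBChatMessageHistory(BaseChatMessageHistory):
    @property
    def messages(self) -> list[BaseMessage]:
        # Query Mongo here and return the stored documents as BaseMessage objects
        raise NotImplementedError("replace with your Mongo query")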
So calling hist.messages triggers your custom logic (like a Mongo query) under the hood. That’s why from the outside you just see hist.messages.copy(), but in practice, the property’s getter does the real DB read. Essentially:

hist is an object you set up in the config, referencing your “Mongo-based chat history” class.
When you do hist.messages, the property’s getter runs your code that fetches documents from Mongo, constructs them as BaseMessage objects, and returns them.
Then .copy() just clones that list in memory.
Hence, the retrieval from Mongo is abstracted away in your custom BaseChatMessageHistory subclass implementation.
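
Here is a self-contained sketch of that property pattern. FakeCollection and FakeMongoHistory are hypothetical stand-ins (no pymongo, no LangChain base class); they only illustrate that plain attribute access can hide a query and that .copy() detaches the result from the store.

```python
class FakeCollection:
    """Hypothetical stand-in for a Mongo collection; counts how often it is queried."""

    def __init__(self, documents: list[dict]) -> None:
        self._documents = documents
        self.query_count = 0

    def find(self, query: dict) -> list[dict]:
        self.query_count += 1
        return [
            doc
            for doc in self._documents
            if all(doc.get(key) == value for key, value in query.items())
        ]


class FakeMongoHistory:
    """Toy history class whose messages property performs the read."""

    def __init__(self, collection: FakeCollection, session_id: str) -> None:
        self._collection = collection
        self._session_id = session_id

    @property
    def messages(self) -> list[str]:
        # Runs on every attribute access, so each access is one query.
        docs = self._collection.find({"session_id": self._session_id})
        return [doc["content"] for doc in docs]


collection = FakeCollection(
    [
        {"session_id": "abc", "content": "hi"},
        {"session_id": "abc", "content": "hello"},
        {"session_id": "xyz", "content": "other session"},
    ]
)
hist = FakeMongoHistory(collection, "abc")
messages = hist.messages.copy()  # attribute access triggers the query
messages.append("local only")    # changes the copy, not the stored data
```

One consequence worth noting: because the getter runs on every access, reading hist.messages twice means two queries in this sketch.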

You said:
What is the purpose of these two additional runnables:         runnable_sync: Runnable = runnable.with_listeners(on_end=self._exit_history)
        runnable_async: Runnable = runnable.with_alisteners(on_end=self._aexit_history)
In LangChain’s Runnable API, a single Runnable can have two ways of running:

Synchronous (.invoke(...))
Asynchronous (.ainvoke(...))
To support both modes, LangChain allows you to attach different listener callbacks that fire at different lifecycle events (like on_end) for each mode:

runnable.with_listeners(on_end=...) attaches callbacks for the sync execution path.
runnable.with_alisteners(on_end=...) attaches callbacks for the async execution path.
Hence:

runnable_sync: Runnable = runnable.with_listeners(on_end=self._exit_history)
runnable_async: Runnable = runnable.with_alisteners(on_end=self._aexit_history)
runnable_sync is the same underlying Runnable as runnable, but it has a sync “end” listener (on_end=self._exit_history) which is called once .invoke() finishes.
runnable_async is the same underlying Runnable but with an async “end” listener (on_end=self._aexit_history) which is called once .ainvoke() finishes.
Why do this?
If someone calls .invoke() (sync path), the code will eventually trigger on_end=self._exit_history – your method that handles final logic (like saving new messages to Mongo). But if someone calls .ainvoke() (async path), the chain eventually calls on_end=self._aexit_history – an async version that performs the same logic in a fully non-blocking way.

So effectively, the code is preparing two versions of the same chain:

One for synchronous usage, with a synchronous callback.
One for asynchronous usage, with an asynchronous callback.
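
As a rough illustration of why two wrappers exist, here is a plain-Python sketch. WithEndListener, exit_history, and aexit_history are hypothetical names; the sketch only assumes that the sync path can call a normal function while the async path must await a coroutine.

```python
import asyncio
from typing import Awaitable, Callable, Optional

saved: list[str] = []  # stands in for the message store


class WithEndListener:
    """Toy wrapper: runs a function, then fires the matching end listener."""

    def __init__(
        self,
        func: Callable[[str], str],
        on_end: Optional[Callable[[str], None]] = None,
        on_end_async: Optional[Callable[[str], Awaitable[None]]] = None,
    ) -> None:
        self.func = func
        self.on_end = on_end
        self.on_end_async = on_end_async

    def invoke(self, value: str) -> str:
        result = self.func(value)
        if self.on_end is not None:
            self.on_end(result)  # plain call on the sync path
        return result

    async def ainvoke(self, value: str) -> str:
        result = self.func(value)
        if self.on_end_async is not None:
            await self.on_end_async(result)  # awaited on the async path
        return result


def exit_history(result: str) -> None:
    saved.append(f"sync:{result}")


async def aexit_history(result: str) -> None:
    await asyncio.sleep(0)  # placeholder for non-blocking I/O
    saved.append(f"async:{result}")


runnable_sync = WithEndListener(str.upper, on_end=exit_history)
runnable_async = WithEndListener(str.upper, on_end_async=aexit_history)

runnable_sync.invoke("hi")
asyncio.run(runnable_async.ainvoke("hi"))
print(saved)  # ['sync:HI', 'async:HI']
```

Each wrapper carries only the listener that fits its execution path, which mirrors the split between runnable_sync and runnable_async.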
You said:
In _exit_history, it actually adds the human and assistant message to mongo database:     def _exit_history(self, run: Run, config: RunnableConfig) -> None:
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]

        # Get the input messages
        inputs = load(run.inputs)
        input_messages = self._get_input_messages(inputs)
        # If historic messages were prepended to the input messages, remove them to
        # avoid adding duplicate messages to history.
        if not self.history_messages_key:
            historic_messages = config["configurable"]["message_history"].messages
            input_messages = input_messages[len(historic_messages) :]

        # Get the output messages
        output_val = load(run.outputs)
        output_messages = self._get_output_messages(output_val)
        hist.add_messages(input_messages + output_messages)
Yes, exactly. The _exit_history method runs after the entire chain finishes (i.e., once the final output is available). It retrieves:

The user’s input messages (which typically include the new user/HumanMessage for this request).
The newly produced AI or assistant messages (from the chain’s final output).
Then it calls hist.add_messages(...), which in turn (in your custom Mongo-based class) persists those messages to the MongoDB collection.

So effectively, _exit_history merges:

The user’s new question
The newly generated LLM response
…into your conversation history in the database. This is why it’s an “on_end” listener – it only runs once the chain has produced all the needed messages.
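
The slicing step in _exit_history (input_messages[len(historic_messages):]) is easy to miss, so here is a small sketch of it with plain strings instead of message objects. The new_input_messages helper is a hypothetical name that isolates just that branch.

```python
from typing import Optional


def new_input_messages(
    input_messages: list[str],
    historic_messages: list[str],
    history_messages_key: Optional[str],
) -> list[str]:
    # Without a separate history key, the stored history was prepended to the
    # input messages, so drop that prefix to avoid saving it a second time.
    if not history_messages_key:
        return input_messages[len(historic_messages):]
    # With a history key, the input holds only the new messages already.
    return input_messages


historic = ["human: hi", "ai: hello"]

# Case 1: no history key, so the input arrives with the history in front.
combined_input = historic + ["human: what's new?"]
to_save = new_input_messages(combined_input, historic, None) + ["ai: not much"]

# Case 2: a history key is set, so the input is just the new message.
to_save_keyed = new_input_messages(["human: what's new?"], historic, "chat_history")

print(to_save)        # ["human: what's new?", 'ai: not much']
print(to_save_keyed)  # ["human: what's new?"]
```

In both cases only the new turn is handed to add_messages, so earlier messages are not duplicated in storage.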

You said:
This is my chatbot that uses a mix of runnables and graphs: from typing import (
    Annotated, 
    TypedDict, 
    Dict, 
    Literal, 
    Optional,
    TypeVar,
    Union,
    Iterator,
    AsyncIterator,
    AsyncGenerator,
    Sequence,
    TypeAlias,
    override,
    Self,
    Any,
    List,
    Tuple,
)
import os
import json
from bson import ObjectId
from collections import defaultdict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langchain_core.runnables import (
    Runnable, 
    RunnableSerializable, 
    RunnableLambda,
    RunnablePassthrough,
    RunnableParallel,
    RunnableBranch,
    RunnableBinding,
)
from langchain_core.tracers.schemas import Run
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompt_values import PromptValue, ChatPromptValue
from langchain_core.output_parsers.string import StrOutputParser
from langchain.chains.combine_documents.base import (
    DEFAULT_DOCUMENT_SEPARATOR, 
    DEFAULT_DOCUMENT_PROMPT,
)
from langchain_core.prompts import format_document
from langchain.chains.retrieval import create_retrieval_chain
from langchain_core.documents import Document
from langchain_core.messages import (
    SystemMessage, 
    HumanMessage, 
    AIMessage, 
    BaseMessage,    
    AIMessageChunk,
)
from langchain_core.messages import MessageLikeRepresentation
from langchain_core.messages.utils import AnyMessage
from langchain.chat_models.base import BaseChatModel
from langchain_core.embeddings import Embeddings
from langchain_core.outputs import (
    ChatGeneration, 
    ChatGenerationChunk
)
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from pydantic import (
    BaseModel,
    Field, 
    model_validator,
    ConfigDict
)
from langchain_core.retrievers import RetrieverLike
from pymongo import DESCENDING

from .graph_state import State
from .language_models.huggingface import HuggingFaceInference
from .chat_bot_config import ChatBotConfig
# from .local_tools.route_query_tool import RouteQueryTool

from langchain_redis import RedisConfig
from redisvl.query.filter import Tag, FilterExpression
from ..gwblue_redis_vectoretriever.config import VectorStoreSchema
from ..gwblue_redis_vectoretriever.vectorstore import RedisVectorStoreTTL

from .prompts import registry
from .message_history import (
    MongoMessageHistorySchema, 
    MongoMessageHistory, 
)

ChatGenerationLike: TypeAlias = ChatGeneration | Iterator[ChatGeneration] | AsyncIterator[ChatGenerationChunk]

I = TypeVar('I', bound=Union[PromptValue, str, Sequence[MessageLikeRepresentation]])
O = TypeVar('O', bound=ChatGenerationLike)
C = TypeVar('C', bound=BaseChatModel)
S = TypeVar('S', bound=BaseChatModel)

def _clamp_temperature(temp: float) -> float:
    if temp < 1.0:
        return 1.0
    elif temp >= 5.0:
        return 0.0
    else:
        return max(0.0, min(1.0, temp))    

def _textualize_model_input(content: Any) -> str:
    if isinstance(content, str):
        return content

    elif isinstance(content, list):
        text_chunks = []
        for item in content:
            if isinstance(item, dict) and item.get('type') == 'text':
                text_value = item.get('text', '')
                text_chunks.append(text_value)
        return " ".join(text_chunks)

    else:
        return ""
    
def _chunk_pairs(pairs: List[Dict[str, Any]]):
    if len(pairs) % 2 != 0:
        raise ValueError('pairs must contain an even number of elements.')
    for i in range(0, len(pairs), 2):
        yield pairs[i : i + 2]
        
class StreamingResponse(BaseModel):
    type: str
    content: str
    token_usage: Dict[str, Any]
    tool_calls: List[Dict[str, Any]]
    vector_metadata: List[Dict[str, Any]]
    session_id: str
    message_id: str

class ChatBot(RunnableSerializable[I, O]):
    config: ChatBotConfig
    graph: Optional[CompiledStateGraph] = None
    chat_model: BaseChatModel = Field(default=None, exclude=True)
    retry_model: BaseChatModel = Field(default=None, exclude=True)
    safety_model: BaseChatModel = Field(default=None, exclude=True)
    embeddings: Embeddings = Field(default=None, exclude=True)
    vector_store: RedisVectorStoreTTL = Field(default=None, exclude=True)
    message_history: MongoMessageHistory = Field(default=None, exclude=True)

    alt: Optional[bool] = False

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @model_validator(mode='after')
    def load_environment(self) -> Self:
        graph = StateGraph(State)

        hf = HuggingFaceInference(config=self.config,model_types={})
        inference_engine = {
            'tgi': hf,
            'tei': hf,
            'vllm': None,
        }
        
        self.chat_model = inference_engine[self.config.llm.server]('chat_model')
        self.safety_model = inference_engine[self.config.guardrails.server]('guardrails')
        self.embeddings = inference_engine[self.config.embeddings.server]('embeddings')

        if self.alt:
            self.chat_model.llm = self.chat_model.llm.bind(
                temperature=_clamp_temperature(self.config.llm.parameters['temperature']),
                seed=42,
            )

        config = RedisConfig(**{
            'redis_client': self.config.vectorstore.client,
            'metadata_schema': self.config.vectorstore.metadata_schema,
            'embedding_dimensions': self.config.embeddings.dimensions,
            **VectorStoreSchema().model_dump()
        })
            
        self.vector_store = RedisVectorStoreTTL(
            self.embeddings, 
            config=config
        )

        message_schema = MongoMessageHistorySchema(
            session_id=self.config.message_history.session_id,
            connection_string=self.config.message_history.url,
            database_name=self.config.message_history.name,
            collection_name=self.config.message_history.collection_name,
            session_id_key=self.config.message_history.session_id_key,
        )
        self.message_history = MongoMessageHistory(message_schema)

        self.graph = self._compile(graph)

        return self
    
    @property
    @override
    def InputType(self) -> TypeAlias:
        from langchain_core.prompt_values import (
            ChatPromptValueConcrete,
            StringPromptValue,
        )

        return Union[
            str,
            Union[StringPromptValue, ChatPromptValueConcrete],
            list[AnyMessage],
        ]

    def invoke(
        self,
        input: I,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> O:
        """Implementation soon"""
        ...

    async def ainvoke(
        self,
        input: I,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> O:
        """Implementation soon"""
        ...

    def stream(
        self,
        input: I,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> O:
        """Implementation soon"""
        ...

    async def astream(
        self,
        input: I,
        config: Optional[Dict[str, Any]] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamingResponse, None]:
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()
    
        state = {
            'messages': input,
            'metadata': config['metadata'].get('vector_metadata', []),
            'retrieval_mode': config['configurable'].get('retrieval_mode', 'similarity'),
        }        

        async for event in self.graph.astream(state, config, stream_mode='messages'):
            ai_message_chunk, state_args = event
            if not isinstance(ai_message_chunk, AIMessageChunk):
                continue

            yield StreamingResponse(**{
                'type': ai_message_chunk.__class__.__name__,
                'content': ai_message_chunk.content,
                'token_usage': ai_message_chunk.additional_kwargs.get('token_usage', {}),
                'tool_calls': ai_message_chunk.additional_kwargs.get('tool_calls', []),
                'vector_metadata': state_args.get('vector_metadata', []),
                'session_id': str(self.config.message_history.session_id),
                'message_id': ai_message_chunk.additional_kwargs.get('uuid', ''),
            })
    
    @staticmethod
    def preprompt_filter(state: State, metadata: Dict[str, Any]) -> RunnableLambda:
        def create_preprompt_filter(input_data: Dict[str, Any]) -> Dict[str, Any]:
            return {
                **input_data,
                'chat_history': [
                    message for message in input_data.get('chat_history', [])
                    if not isinstance(message, SystemMessage) or not message.additional_kwargs.get('preprompt', False)
                ]
            }
        
        return RunnableLambda(create_preprompt_filter).with_config(
            run_name=f'filter_preprompt_chain_{state['route']}',
            metadata=metadata
        )
    
    @staticmethod
    def create_filter_expression(metadata: Dict[str, Any]) -> FilterExpression:
        from functools import reduce
        import operator

        tag_expressions = [
            Tag(key) == str(value)
            for key, value in metadata.items() 
        ]
        filter_expression = reduce(operator.and_, tag_expressions)
        return filter_expression

    def create_generic_chain(self, state: State, system_prompt: str) -> Runnable:
        """
        Conform to answer key produced by create_retrieval_chain
        using a custom output parser

        Important note: passing only content in output parser instead of entire
        AIMessage means that any metadata associated with that AIMessage is lost,
        such as token_usage or logprobs
        """
        answer_parser = RunnableLambda(lambda ai_message: {'answer': ai_message.content } )
        chain = registry['chat_prompt_with_history'](system_prompt) | self.chat_model | answer_parser
        
        return chain.with_config(run_name=f'generic_chat_model_chain_{state['route']}')    

    def create_history_aware_retriever(
        self,
        retriever: RetrieverLike,
        prompt: BasePromptTemplate,
        preprompt_filter: Optional[Runnable] = None,
    ) -> Runnable:
        """Custom implementation to handle preprompt messages"""
        def validate_history(input_data: Dict[str, Any]) -> bool:
            return not input_data.get('chat_history')
            
        retrieve_documents = (preprompt_filter or RunnablePassthrough()) | RunnableBranch(
            (
                validate_history,
                (lambda input_data: input_data['input']) | retriever,
            ),
            prompt
            | self.chat_model
            | StrOutputParser()
            | retriever,
        ).with_config(run_name='history_aware_retriever_chain')
        
        return retrieve_documents

    def create_stuff_documents_chain(
        self,
        prompt: BasePromptTemplate,
        preprompt_filter: Optional[Runnable] = None,
    ) -> Runnable[Dict[str, Any], Any]:
        """Custom implementation to handle preprompt messages"""        
        def format_docs(inputs: dict) -> str:
            return DEFAULT_DOCUMENT_SEPARATOR.join(
                format_document(doc, DEFAULT_DOCUMENT_PROMPT)
                for doc in inputs['context']
            )

        return (
            (preprompt_filter or RunnablePassthrough())
            | RunnablePassthrough.assign(context=format_docs).with_config(run_name='format_inputs')
            | prompt
            | self.chat_model
            | StrOutputParser()
        ).with_config(run_name='stuff_documents_chain')    

    def create_context_aware_chain(self, state: State) -> Runnable:
        metadata = state['metadata'][0]
        system_prompt = state['messages'][0].content

        filter_expression = self.create_filter_expression(metadata)
        search_kwargs = {
            'k': 6,
            'filter': filter_expression,
        }
        if state['retrieval_mode'] == 'similarity_score_threshold':
            search_kwargs['score_threshold'] = 0.8

        retriever = self.vector_store.as_retriever(
            search_type=state['retrieval_mode'],
            search_kwargs=search_kwargs
        ).with_config(
            tags=[f'create_context_aware_chain_{state['route']}'],
            metadata=metadata,           
        )
        
        history_aware_retriever = self.create_history_aware_retriever(
            retriever,
            registry['contextualized_template'](),
            preprompt_filter=self.preprompt_filter(state, metadata),
        )

        question_answer_chain = self.create_stuff_documents_chain(
            registry['qa_template'](system_prompt),
            preprompt_filter=self.preprompt_filter(state, metadata)
        )
        
        return create_retrieval_chain(history_aware_retriever, question_answer_chain)

    def create_multi_retriever_chain(
        self, 
        retrievers: Tuple[List[Runnable], Dict[str, Any]],
        state: State
    ) -> Runnable:
        context_prompt = registry['contextualized_template']()
        retriever_map = {
            f'Source {metadata['source']}': retriever
            for retriever, metadata in retrievers
        }
        parallel_retrieval = RunnableParallel(retriever_map)

        def combine_contexts(retrieved_results: dict, separator=DEFAULT_DOCUMENT_SEPARATOR) -> list:
            combined_results = []
            for key, docs in retrieved_results.items():
                combined_docs = separator.join(doc.page_content for doc in docs)
                combined_results.append(
                    Document(page_content=f'Context from {key}:\n{combined_docs}')
                )

            return combined_results
        
        combine_contexts_runnable = RunnableLambda(combine_contexts) \
            .with_config(run_name='combine_context_chain')

        retrieve_documents = self.create_history_aware_retriever(
            parallel_retrieval,
            context_prompt,
            preprompt_filter=self.preprompt_filter(state, {})
        )
        
        return retrieve_documents | combine_contexts_runnable
    
    def create_multi_stuff_chain(self, state: State, system_prompt: str) -> Runnable:
        qa_template = registry['qa_template'](system_prompt)

        return self.create_stuff_documents_chain(
            qa_template,
            preprompt_filter=self.preprompt_filter(state, {})
        )

    def create_multicontext_aware_chain(self, state: State) -> Runnable:
        system_prompt = state['messages'][0].content
        retrievers = []
        
        for index, metadata in enumerate(state['metadata']):
            filter_expression = self.create_filter_expression(metadata)
            search_kwargs = {
                'k': 6,
                'filter': filter_expression,
            }
            if state['retrieval_mode'] == 'similarity_score_threshold':
                search_kwargs['score_threshold'] = 0.8

            retriever = self.vector_store.as_retriever(
                search_type=state['retrieval_mode'],
                search_kwargs=search_kwargs
            ).with_config(
                tags=[f'create_context_aware_chain_{index}_{state['route']}'],
                metadata=metadata,           
            )
            retrievers.append((retriever, metadata))

        multi_retriever_chain = self.create_multi_retriever_chain(retrievers, state)
        stuffing_chain = self.create_multi_stuff_chain(state, system_prompt)
        
        multicontext_aware_chain = (
            RunnablePassthrough.assign(
                context=multi_retriever_chain.with_config(run_name='retrieval_chain'),
            ).assign(answer=stuffing_chain)
        ).with_config(run_name='multicontext_aware_chain')

        return multicontext_aware_chain

    async def _aenter_chat_chain(self, run: Run, config: RunnableConfig, system_prompt: str) -> Optional[SystemMessage]:
        """On start runnable listener"""
        collection = self.message_history.chat_message_history.collection
        
        document = collection.find_one({
            'type': 'system', 
            'content': system_prompt, 
            self.message_history._schema.session_id_key: self.message_history._schema.session_id,
        })
        
        if document is None:
            await self.message_history.asystem(system_prompt, additional_kwargs={'preprompt': True})
        else:
            history_data = json.loads(document['History'])
            
            additional_kwargs = history_data.get('data', {}).get('additional_kwargs', {})
            if not additional_kwargs.get('preprompt', False):
                await self.message_history.asystem(system_prompt, additional_kwargs={'preprompt': True})

    async def _aexit_chat_chain(self, run: Run, config: RunnableConfig) -> None:
        """On end runnable listener"""
        collection = self.message_history.chat_message_history.collection
        if(
            ai_message := collection.find_one(
                {
                    'type': { '$in': ['ai', 'AIMessageChunk'] }, 
                    self.config.message_history.session_id_key: self.config.message_history.session_id,
                }, 
                sort=[("createdAt", DESCENDING)])
        ) is not None:
            chain = registry['summarization_template']() | self.chat_model.bind(stream=False)
            summary = await chain.ainvoke({'input': ai_message['content']})
            self.message_history.chat_message_history.add_summary(summary.content)    

    async def generate_with_history(
        self, 
        state: State, 
        chain: Runnable,
        *,
        config: Optional[RunnableConfig] = None,
    ) -> dict:
        system_prompt = state['messages'][0].content
        human_message = state['messages'][-2]

        if isinstance(human_message.content, str):
            input_dict = {'input': human_message.content}
        elif isinstance(human_message.content, list):
            input_dict = { 'input': human_message.content[1]['text'] }

        async def on_start(run: Run, config: RunnableConfig):
            await self._aenter_chat_chain(run, config, system_prompt)

        async def on_end(run: Run, config: RunnableConfig):
            await self._aexit_chat_chain(run, config)

        chain_with_history = self.message_history.get(chain).with_alisteners(
            on_start=on_start,
            on_end=on_end
        )

        if not config or 'session_id' not in config.get('configurable', {}):
            # RunnableConfig is a dict, so read its keys with .get() rather than attribute access
            config = RunnableConfig(
                configurable={'session_id': self.config.message_history.session_id}  # TODO: generalize to session_id
            )
        chain_values = await chain_with_history.ainvoke(input_dict, config=config)
        return { 'messages': [AIMessage(content=chain_values['answer'])] }
    
    def _compile(self, graph: StateGraph):
        async def guardrails(state: State) -> State:
            user_content = state['messages'][1].content
            sanitized_text = _textualize_model_input(user_content)
            ai_message = await self.safety_model.bind(stream=False).ainvoke([sanitized_text])

            guardrails_message = AIMessage(
                content=ai_message.content,
                additional_kwargs={'guardrails': True},
            )
            return {**state, 'messages': [guardrails_message]}

        def guardrails_condition(state: State) -> str:
            last_msg: AIMessage = state["messages"][-1]
            text = last_msg.content.lower().strip('\n')
            if 'unsafe' in text:
                return 'not_safe'
            elif 'safe' in text:
                return 'prefill_system_prompt'
            else:
                return 'not_safe'

        def not_safe(_: State):
            return {
                "messages": [
                    AIMessageChunk(
                        content="Your request cannot be processed. (Content flagged as not safe.)"
                    )
                ]
            }

        async def prefill_system_prompt(state: State) -> State:
            import copy

            system_message = state['messages'][0]
            human_message = state['messages'][-2]
            
            if not isinstance(human_message.content, list):
                return state

            pairs = list(_chunk_pairs(human_message.content))
            batches = []
            for pair in pairs:
                cloned_pair = [copy.deepcopy(item) for item in pair]
                cloned_pair[1]['text'] = 'Describe the image in-depth.'
                messages = [
                    SystemMessage(content=system_message.content),
                    HumanMessage(content=cloned_pair),
                ]
                batches.append(messages)
            
            non_streaming_model = self.chat_model.bind(stream=False)
            ai_messages = await non_streaming_model.abatch(batches)

            labeled_descs = []
            for i, ai_message in enumerate(ai_messages, start=1):
                desc = ai_message.content.strip()
                labeled_descs.append(f"**Image #{i}**:\n{desc}")
            combined_desc = "\n\n".join(labeled_descs)

            old_prompt = system_message.content
            updated_prompt = (
                old_prompt
                + "\n\n"
                + "If the user asks about images, pictures, or photos, then use the following descriptions as if they were the actual images:\n"
                + combined_desc
            )

            new_system_message = system_message.copy(update={'content': updated_prompt})
            state['messages'][0] = new_system_message

            return state

        async def route_query(state: State):
            """
            Account for scenarios:
            - 'Explain this document' (where this is not specified and refers to upload)
            - 'Compare these two documents' (where these is not specified and refers to multiple uploads)
            - 'Explain why it has impact' (where it refers to previously vectorized data)
            - 'Compare this and that' (where this is current upload and that is previously vectorized data)
            - 'Explain something' (where no vector data but can tap into pretrained corpus of LLM)
            """
            metadata = state['metadata']

            if len(metadata) > 1:
                # Spread state first so the new 'route' is not overwritten by an older value
                return {**state, 'route': 'multi_doc_prompt'}
            
            if len(metadata) == 1 and 'source' in metadata[0]:
                return {**state, 'route': 'single_doc_prompt'}
        
            human_prompt = _textualize_model_input(state['messages'][-2].content)
            vector_filter = metadata[0]

            filter_expression = (
                (Tag(self.config.message_history.session_id_key) == str(self.config.message_history.session_id)) 
                & (Tag('uuid') == vector_filter['uuid'])
            )
            relevant_docs_with_score = await self.vector_store.asimilarity_search_with_score(
                query=human_prompt,
                k=20,
                filter=filter_expression,
            )
            file_to_best_chunk = defaultdict(lambda: (None, float('inf')))
            for doc, dist in relevant_docs_with_score:
                fname = doc.metadata.get('source', '')                
                if dist < file_to_best_chunk[fname][1]:
                    file_to_best_chunk[fname] = (doc, dist)

            best_metadata = [doc.metadata for (doc, _) in file_to_best_chunk.values() if doc is not None]

            num_files = len(best_metadata)
            if num_files > 1:
                route = 'multi_doc_prompt'
            elif num_files == 1:
                route = 'single_doc_prompt'
            else:
                route = 'pretrained_corpus_prompt'
            
            return {
                **state,
                'route': route,
                'metadata': best_metadata,
            }  

        def route_query_condition(state: State) -> str:
            return state['route']
        
        async def single_doc_prompt(state: State) -> dict:
            """
            Generate prompt for single document
            """
            chain = self.create_context_aware_chain(state)
            messages = await self.generate_with_history(state, chain)
            return messages

        async def multi_doc_prompt(state: State) -> Dict[str, Any]:
            """
            Generate prompt for multiple documents
            """
            chain = self.create_multicontext_aware_chain(state)
            messages = await self.generate_with_history(state, chain)
            return messages

        async def pretrained_corpus_prompt(state: State) -> Dict[str, Any]:
            """
            Generate prompt for pretrained corpus
            """
            system_prompt = state['messages'][0].content
            chain = self.create_generic_chain(state, system_prompt)
            messages = await self.generate_with_history(state, chain)
            return messages

        graph.add_node('guardrails', guardrails)
        graph.add_node('not_safe', not_safe)
        graph.add_node('prefill_system_prompt', prefill_system_prompt)
        graph.add_node('route_query', route_query)
        graph.add_node('single_doc_prompt', single_doc_prompt)
        graph.add_node('multi_doc_prompt', multi_doc_prompt)
        graph.add_node('pretrained_corpus_prompt', pretrained_corpus_prompt)

        graph.add_edge(START, 'guardrails')
        graph.add_conditional_edges(
            'guardrails',
            guardrails_condition,
            {
                'prefill_system_prompt': 'prefill_system_prompt',
                'not_safe': 'not_safe'
            }
        )
        graph.add_edge('prefill_system_prompt', 'route_query')
        graph.add_edge('not_safe', END)
        graph.add_conditional_edges(
            'route_query',
            route_query_condition,
            {
                #'vectorstore': 'vectorstore',
                #'dataframe_tool': 'dataframe_tool',
                'single_doc_prompt': 'single_doc_prompt',
                'multi_doc_prompt': 'multi_doc_prompt',
                'pretrained_corpus_prompt': 'pretrained_corpus_prompt',
            }
        )
        #graph.add_edge('vectorstore', END)
        #graph.add_edge('dataframe_tool', END)
        graph.add_edge('single_doc_prompt', END)
        graph.add_edge('multi_doc_prompt', END)
        graph.add_edge('pretrained_corpus_prompt', END)
        return graph.compile()


        """
        Determine if multiple candidate completions (
        use both semantic similarity and max marginal
        relevance
        )
        Go to vector store find vectors based on the single
        doc
        AND ALSO find vectors based on the query for the given
        conversation_id (just in case they ask to compare with
        a previous doc)
        """

The graph implementation is a newer concept, and I am slowly migrating towards it. As you can see, I build up a runnable chain that I later pass to RunnableWithMessageHistory. As an example, let's look at the path of single_doc_prompt:

        async def single_doc_prompt(state: State) -> dict:
            """
            Generate prompt for single document
            """
            chain = self.create_context_aware_chain(state)
            messages = await self.generate_with_history(state, chain)
            return messages

The chain built in create_context_aware_chain is the one that gets passed to RunnableWithMessageHistory. How is RunnableWithMessageHistory able to tell which AIMessage in the chain to store in the database and which not to? For example, in create_history_aware_retriever we invoke an LLM:

    def create_history_aware_retriever(
        self,
        retriever: RetrieverLike,
        prompt: BasePromptTemplate,
        preprompt_filter: Optional[Runnable] = None,
    ) -> Runnable:
        """Custom implementation to handle preprompt messages"""
        def validate_history(input_data: Dict[str, Any]) -> bool:
            return not input_data.get('chat_history')
            
        retrieve_documents = (preprompt_filter or RunnablePassthrough()) | RunnableBranch(
            (
                validate_history,
                (lambda input_data: input_data['input']) | retriever,
            ),
            prompt
            | self.chat_model
            | StrOutputParser()
            | retriever,
        ).with_config(run_name='history_aware_retriever_chain')
        
        return retrieve_documents

All the LLM does here is take the prompt, pull in some history, and add context to the prompt, such as replacing "it" with the specific thing it refers to. Yet this model invocation is never stored as a newly created AIMessage. How does RunnableWithMessageHistory know not to store this AIMessage, but then later store the one created here:

    def create_stuff_documents_chain(
        self,
        prompt: BasePromptTemplate,
        preprompt_filter: Optional[Runnable] = None,
    ) -> Runnable[Dict[str, Any], Any]:
        """Custom implementation to handle preprompt messages"""        
        def format_docs(inputs: dict) -> str:
            return DEFAULT_DOCUMENT_SEPARATOR.join(
                format_document(doc, DEFAULT_DOCUMENT_PROMPT)
                for doc in inputs['context']
            )

        return (
            (preprompt_filter or RunnablePassthrough())
            | RunnablePassthrough.assign(context=format_docs).with_config(run_name='format_inputs')
            | prompt
            | self.chat_model
            | StrOutputParser()
        ).with_config(run_name='stuff_documents_chain')

In both cases we use self.chat_model, which is of type BaseChatModel, yet one output gets stored as an AIMessage and the other does not.

Reasoned for 8 seconds
In RunnableWithMessageHistory, only the chain’s final output gets added to the chat history – not every intermediate LLM call. In other words:

Every LLM invocation in your chain (e.g. inside create_history_aware_retriever or create_stuff_documents_chain) might produce an AIMessage internally as part of the pipeline.
But RunnableWithMessageHistory looks only at the final return of the chain once it finishes. It then identifies the user’s input message(s) plus the final AI message(s) in that final output. Those are what it stores in your database.
How RunnableWithMessageHistory does this
Attaching a “listener”
RunnableWithMessageHistory wraps your chain with an on_end listener: with_listeners(on_end=self._exit_history) for sync runs and, in recent langchain-core versions, with_alisteners(on_end=self._aexit_history) for async runs.
In _exit_history:
python
# 1) Get the final run outputs
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)

# 2) Store them with hist.add_messages(...)
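_get_output_messages() normalizes the final chain output into a list of messages. If the output is a dictionary, it reads the key named by output_messages_key (falling back to "output" when that is unset), so a chain that returns {"answer": ...} needs output_messages_key="answer". A string under that key is wrapped as an AIMessage; a single message or a list of messages is used as-is.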
Intermediate LLM calls never appear in that final dictionary. They’re ephemeral steps in the pipeline. So _exit_history doesn’t see them, and thus never writes them to Mongo.
Concretely in your example
create_history_aware_retriever(...) calls self.chat_model to produce some text. That partial text is used to refine a prompt or retrieve documents – but the chain eventually moves on.
The output of that intermediate LLM call is not the chain's final return value. It is handed to the retriever, and the retrieved documents end up under a key such as "context".
Eventually you reach create_stuff_documents_chain(...): ... | self.chat_model | StrOutputParser(). That portion typically yields the final text answer. The chain’s final return might look like {"answer": "Here’s your final answer"} or a single AIMessage(...).
Because RunnableWithMessageHistory sees only that final object as the “chain output,” it’s the only one stored in the DB.
Why only the final answer is stored
By default, RunnableWithMessageHistory is designed to store one “user input” plus one “assistant output” for each chain invocation. It does not track every intermediate node.
If you want to store multiple LLM calls, you’d have to either:
Return them as a list of messages under output_messages_key in the chain’s final dictionary so _exit_history sees them, or
Create custom logic in your chain to do partial updates to the DB.
But out of the box, RunnableWithMessageHistory is meant to store “the user’s question” and “the final AI answer” from each chain run. All other AI calls along the way remain internal.
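
To make the "only the final output is stored" point concrete, here is a plain-Python sketch. It is a toy analogue and not LangChain code: fake_model, chain, and run_with_history are hypothetical stand-ins for the chat model, your composed chain, and the history wrapper. The chain makes two model calls, but the wrapper records only the input and the value under the output key.

```python
from typing import Callable, Dict, List

calls: List[str] = []  # every "model call" made during the run


def fake_model(prompt: str) -> str:
    """Hypothetical stand-in for a chat model: records the prompt, returns text."""
    calls.append(prompt)
    return f"reply to [{prompt}]"


def chain(inputs: Dict) -> Dict:
    """Two model calls: a query rewrite (intermediate) and the answer (final)."""
    rewritten = fake_model(f"rewrite: {inputs['input']}")
    answer = fake_model(f"answer using: {rewritten}")
    return {"context": [rewritten], "answer": answer}


def run_with_history(
    chain_fn: Callable[[Dict], Dict],
    inputs: Dict,
    history: List[str],
    output_key: str = "answer",
) -> Dict:
    """Run the chain, then record only the input and the value under output_key."""
    outputs = chain_fn(inputs)
    history.append(f"human: {inputs['input']}")
    history.append(f"ai: {outputs[output_key]}")
    return outputs


history: List[str] = []
result = run_with_history(chain, {"input": "What is it?"}, history)
```

After this runs, calls holds two prompts while history holds one human entry and one AI entry; the rewrite step never reaches the history because it is not part of the final return value.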

You said:
When this chain is invoked, what exactly happens?

        def format_docs(inputs: dict) -> str:
            return DEFAULT_DOCUMENT_SEPARATOR.join(
                format_document(doc, DEFAULT_DOCUMENT_PROMPT)
                for doc in inputs['context']
            )

        return (
            (preprompt_filter or RunnablePassthrough())
            | RunnablePassthrough.assign(context=format_docs).with_config(run_name='format_inputs')
            | prompt
            | self.chat_model
            | StrOutputParser()
        ).with_config(run_name='stuff_documents_chain')    

Reasoned about `RunnableWithMessageHistory` behavior for 7 seconds
Here’s the high-level flow when you call, for example, my_chain.invoke({"context": [...docs...]}) on that snippet:

python
def format_docs(inputs: dict) -> str:
    return DEFAULT_DOCUMENT_SEPARATOR.join(
        format_document(doc, DEFAULT_DOCUMENT_PROMPT)
        for doc in inputs["context"]
    )

chain = (
    (preprompt_filter or RunnablePassthrough())
    | RunnablePassthrough.assign(context=format_docs)
    | prompt
    | self.chat_model
    | StrOutputParser()
).with_config(run_name="stuff_documents_chain")
Input → (preprompt_filter or RunnablePassthrough())

You pass in some dictionary (for example: {"context": [Document(...), ...], "input": "..."}).
If preprompt_filter exists (another Runnable), it processes that input. Otherwise, RunnablePassthrough() simply passes the dictionary unchanged.
The result is still a dictionary at this point.
→ RunnablePassthrough.assign(context=format_docs)

Now the chain calls format_docs(inputs) on the dictionary from step 1.
format_docs iterates over inputs["context"] (which should be a list of Documents), merges them into a string, and returns it.
Then .assign(context=...) sets that result as the new "context" field in the dictionary.
So if you originally had {"context": [Document1, ...], "input": "..."}, by the end of this step you have something like:
python
{
  "context": "<formatted string of docs>",
  "input": "...",
  ... # plus other keys from the original
}
→ prompt

Next, the dictionary (with "context" now a string) is fed into your BasePromptTemplate (named prompt).
That template typically expects certain keys (for example, it might have placeholders like {context}, {input} in the template).
It substitutes them to build a PromptValue or ChatPromptValue, i.e. the prompt text/messages for the LLM.
→ self.chat_model

The ChatPromptValue from step 3 is passed to your model (self.chat_model).
The model then generates a response (for example, an AIMessage), presumably applying the newly combined context as part of the prompt.
→ StrOutputParser()

Finally, the LLM’s output (an AIMessage or text) goes through StrOutputParser(), which usually extracts just the .content as a string.
That is the final output of the chain. If you call chain.invoke(...), you end up with a plain string from this parser.
So in summary:

Possibly filter/transform the input dictionary with preprompt_filter.
Convert the list of documents in "context" into one long text string (format_docs), store that back in "context".
Use that dictionary to fill the prompt template.
Pass the prompt to self.chat_model to generate a textual response.
Parse that response to a string with StrOutputParser.
Return that final string to the caller.
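
If it helps to see the data shape at each step, the flow above can be traced with plain functions. This is only an analogue under stated assumptions: documents are plain strings, SEPARATOR is assumed to match DEFAULT_DOCUMENT_SEPARATOR, and fake_chat_model is a hypothetical stand-in that returns a message-like dict.

```python
from typing import Dict

SEPARATOR = "\n\n"  # assumed value of DEFAULT_DOCUMENT_SEPARATOR


def format_docs(inputs: Dict) -> str:
    """Join the documents (plain strings here) into one context string."""
    return SEPARATOR.join(inputs["context"])


def assign_context(inputs: Dict) -> Dict:
    """Analogue of RunnablePassthrough.assign(context=format_docs)."""
    return {**inputs, "context": format_docs(inputs)}


def prompt(inputs: Dict) -> str:
    """Analogue of the prompt template: fill {context} and {input}."""
    return f"Context:\n{inputs['context']}\n\nQuestion: {inputs['input']}"


def fake_chat_model(prompt_text: str) -> Dict:
    """Hypothetical model: returns a message-like dict, not a real AIMessage."""
    return {"type": "ai", "content": f"ANSWER<{prompt_text}>"}


def str_output_parser(message: Dict) -> str:
    """Analogue of StrOutputParser: keep only the content."""
    return message["content"]


def stuff_documents_chain(inputs: Dict) -> str:
    state = dict(inputs)             # step 1: passthrough (no preprompt_filter)
    state = assign_context(state)    # step 2: "context" becomes a string
    prompt_text = prompt(state)      # step 3: fill the template
    message = fake_chat_model(prompt_text)  # step 4: model call
    return str_output_parser(message)       # step 5: plain string out


original = {"context": ["Doc 1", "Doc 2"], "input": "Summarize"}
result = stuff_documents_chain(original)
```

Note that the caller's dictionary is not mutated; each step produces a new value, which is also how the real Runnables behave when composed with the pipe operator.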
You said:
Where are the format_docs coming from? Where are they being loaded?

Reasoned for 10 seconds
In your snippet:

python
return (
    (preprompt_filter or RunnablePassthrough())
    | RunnablePassthrough.assign(context=format_docs).with_config(run_name='format_inputs')
    | prompt
    | self.chat_model
    | StrOutputParser()
)
the call format_docs(inputs) relies on inputs["context"] – meaning your pipeline expects to see a dictionary (the “inputs”) containing a key "context" which is a list of Document objects. Something upstream in the chain is responsible for putting that list of documents in "context". Typically, it’s one of two scenarios:

The chain’s caller provides the documents when calling .invoke() or .ainvoke().
For example:

python
docs = [Document(page_content="Doc 1..."), Document(page_content="Doc 2...")]
output = my_chain.invoke({"context": docs, "input": "Please summarize"})
Then format_docs sees docs in inputs["context"].

A previous step in the chain stored documents under "context".
For instance, if you have a retrieval step like retriever.invoke(...) returning [Document(...), ...], you might do:

python
# The retriever expects the query string, not the whole dict, so extract it first
chain_part = RunnablePassthrough.assign(context=(lambda x: x["input"]) | retriever)
# ...
chain = chain_part | prompt | self.chat_model
In that case, the retrieval step’s output becomes the "context" field for the next step.

Therefore, the “format_docs” function sees inputs["context"] because something earlier in the pipeline attached a list of Documents to that key. That “something” might be a retrieval chain, user input, or any other node in the pipeline that sets "context" in the dictionary.

You said:
And create_retrieval_chain puts that resulting text under answer:

def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create retrieval chain that retrieves documents and then passes them on.

    Args:
        retriever: Retriever-like object that returns list of documents. Should
            either be a subclass of BaseRetriever or a Runnable that returns
            a list of documents. If a subclass of BaseRetriever, then it
            is expected that an input key be passed in - this is what
            is will be used to pass into the retriever. If this is NOT a
            subclass of BaseRetriever, then all the inputs will be passed
            into this runnable, meaning that runnable should take a dictionary
            as input.
        combine_docs_chain: Runnable that takes inputs and produces a string output.
            The inputs to this will be any original inputs to this chain, a new
            context key with the retrieved documents, and chat_history (if not present
            in the inputs) with a value of [] (to easily enable conversational
            retrieval.

    Returns:
        An LCEL Runnable. The Runnable return is a dictionary containing at the very
        least a context and answer key.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain.chains import create_retrieval_chain
            from langchain import hub

            retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
            llm = ChatOpenAI()
            retriever = ...
            combine_docs_chain = create_stuff_documents_chain(
                llm, retrieval_qa_chat_prompt
            )
            retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)

            retrieval_chain.invoke({"input": "..."})

    """
    if not isinstance(retriever, BaseRetriever):
        retrieval_docs: Runnable[dict, RetrieverOutput] = retriever
    else:
        retrieval_docs = (lambda x: x["input"]) | retriever

    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
        ).assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")

    return retrieval_chain

But then, if the result is text, how does RunnableWithMessageHistory know what the AIMessage is when all that chain outputs is text?
Reasoned for a second
Because RunnableWithMessageHistory has a small piece of logic that, at the very end, converts the chain’s final output (the text under "answer") into an AIMessage. Specifically, its _get_output_messages(...) method:

Sees a dictionary like {"context": [...], "answer": "some text from the chain"}.
Provided the wrapper was constructed with output_messages_key="answer", it reads the value under that key and, because it is a string, converts it to an AIMessage(content="some text from the chain").
Then it stores that AIMessage in the database.
So from your chain’s perspective, all it returns is a dictionary with a string under "answer". But RunnableWithMessageHistory knows, “The final message is at output_messages_key='answer'; if that is a string, I wrap it as an AIMessage(content=that_string).” That’s how the text becomes an AIMessage for your chat history.
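
Here is a simplified sketch of that normalization step. It follows the behavior described above but is not the library's actual source: the AIMessage dataclass is a minimal stand-in and get_output_messages is a hypothetical free function, where the real logic is a method on RunnableWithMessageHistory.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class AIMessage:
    """Minimal stand-in for langchain_core.messages.AIMessage."""
    content: str


def get_output_messages(
    output_val: Any, output_messages_key: Optional[str] = None
) -> List[AIMessage]:
    """Normalize a chain's final output into a list of messages to store."""
    if isinstance(output_val, dict):
        # Only the configured key is read; "context" and other keys are ignored.
        output_val = output_val[output_messages_key or "output"]
    if isinstance(output_val, str):
        return [AIMessage(content=output_val)]  # plain text gets wrapped
    if isinstance(output_val, AIMessage):
        return [output_val]
    if isinstance(output_val, (list, tuple)):
        return list(output_val)
    raise ValueError(f"Unsupported output type: {type(output_val).__name__}")


stored = get_output_messages(
    {"context": ["doc 1", "doc 2"], "answer": "some text from the chain"},
    output_messages_key="answer",
)
```

The same function also explains why a chain ending in StrOutputParser() works without any extra wiring: a bare string output takes the second branch and is wrapped directly.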

You said:
What is the eventual output of this chain?

    def create_history_aware_retriever(
        self,
        retriever: RetrieverLike,
        prompt: BasePromptTemplate,
        preprompt_filter: Optional[Runnable] = None,
    ) -> Runnable:
        """Custom implementation to handle preprompt messages"""
        def validate_history(input_data: Dict[str, Any]) -> bool:
            return not input_data.get('chat_history')
            
        retrieve_documents = (preprompt_filter or RunnablePassthrough()) | RunnableBranch(
            (
                validate_history,
                (lambda input_data: input_data['input']) | retriever,
            ),
            prompt
            | self.chat_model
            | StrOutputParser()
            | retriever,
        ).with_config(run_name='history_aware_retriever_chain')
        
        return retrieve_documents

Reasoned about create_history_aware_retriever chain output
Looking at this snippet:

python
def create_history_aware_retriever(
    self,
    retriever: RetrieverLike,
    prompt: BasePromptTemplate,
    preprompt_filter: Optional[Runnable] = None,
) -> Runnable:
    def validate_history(input_data: Dict[str, Any]) -> bool:
        return not input_data.get('chat_history')
            
    retrieve_documents = (preprompt_filter or RunnablePassthrough()) | RunnableBranch(
        (
            validate_history,
            (lambda input_data: input_data['input']) | retriever,
        ),
        prompt
        | self.chat_model
        | StrOutputParser()
        | retriever,
    ).with_config(run_name='history_aware_retriever_chain')
    
    return retrieve_documents
the eventual final output of this chain is a list of Documents produced by the retriever. Here’s how it works step by step:

(preprompt_filter or RunnablePassthrough()):

If you have a preprompt_filter runnable, it modifies the incoming dictionary; otherwise, the RunnablePassthrough leaves it as is.
| RunnableBranch(...):

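RunnableBranch takes one or more (condition, runnable) pairs followed by a default runnable. Here there is a single pair, whose condition is validate_history, plus the default:

Condition: validate_history(input_data) → bool (True when there is no chat history)
If validate_history returns True → use the first branch. Otherwise → fall through to the default (second) branch.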
First branch:

python
(lambda input_data: input_data["input"]) | retriever
That effectively says:

Grab input_data["input"] (usually the user query).
Pass that string to retriever to get a list of Documents.
Second branch:

python
prompt
| self.chat_model
| StrOutputParser()
| retriever
That means we:

Format the dictionary with prompt (which might incorporate chat history or other template fields).
Call self.chat_model to get some generated text.
Parse that LLM text into a string (StrOutputParser()).
Finally, pass that string to retriever, which also returns a list of Documents.
At the end:
No matter which branch is taken, the final piece is always retriever – so the chain’s final output is the list of Documents returned by that retriever call.

Hence, whether we skip chat history or not, the chain returns a list of Documents that downstream steps can use.
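
The branching logic can be sketched in plain Python as follows. This is an analogue rather than LangChain code: run_branch mimics "first matching condition wins, otherwise the default", retriever is a hypothetical function returning strings instead of Documents, and rewrite_with_history stands in for prompt | self.chat_model | StrOutputParser().

```python
from typing import Any, Callable, Dict, List, Tuple

Branch = Tuple[Callable[[Dict], bool], Callable[[Dict], Any]]


def run_branch(branches: List[Branch], default: Callable[[Dict], Any], inputs: Dict) -> Any:
    """Run the first branch whose condition is true, otherwise the default."""
    for condition, runnable in branches:
        if condition(inputs):
            return runnable(inputs)
    return default(inputs)


def retriever(query: str) -> List[str]:
    """Hypothetical retriever: returns strings instead of Document objects."""
    return [f"doc matching '{query}'"]


def validate_history(input_data: Dict) -> bool:
    return not input_data.get("chat_history")


def rewrite_with_history(input_data: Dict) -> str:
    """Stand-in for the LLM step that rewrites the question using history."""
    n = len(input_data["chat_history"])
    return f"{input_data['input']} (resolved using {n} prior messages)"


def history_aware_retriever(input_data: Dict) -> List[str]:
    return run_branch(
        [(validate_history, lambda x: retriever(x["input"]))],
        lambda x: retriever(rewrite_with_history(x)),
        input_data,
    )


no_history = history_aware_retriever({"input": "What is LCEL?", "chat_history": []})
with_history = history_aware_retriever(
    {"input": "Explain it", "chat_history": ["human: What is LCEL?", "ai: ..."]}
)
```

Both calls return a list of retrieved items; the only difference is whether the query was rewritten before it reached the retriever.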

You said:
OK, so one chain populates the documents for context, and the other chain generates the chat completion. But of course that second chain will need access to the context documents. I don't see how that is happening here:

def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create retrieval chain that retrieves documents and then passes them on.

    Args:
        retriever: Retriever-like object that returns list of documents. Should
            either be a subclass of BaseRetriever or a Runnable that returns
            a list of documents. If a subclass of BaseRetriever, then it
            is expected that an input key be passed in - this is what
            is will be used to pass into the retriever. If this is NOT a
            subclass of BaseRetriever, then all the inputs will be passed
            into this runnable, meaning that runnable should take a dictionary
            as input.
        combine_docs_chain: Runnable that takes inputs and produces a string output.
            The inputs to this will be any original inputs to this chain, a new
            context key with the retrieved documents, and chat_history (if not present
            in the inputs) with a value of [] (to easily enable conversational
            retrieval.

    Returns:
        An LCEL Runnable. The Runnable return is a dictionary containing at the very
        least a context and answer key.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain.chains import create_retrieval_chain
            from langchain import hub

            retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
            llm = ChatOpenAI()
            retriever = ...
            combine_docs_chain = create_stuff_documents_chain(
                llm, retrieval_qa_chat_prompt
            )
            retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)

            retrieval_chain.invoke({"input": "..."})

    """
    if not isinstance(retriever, BaseRetriever):
        retrieval_docs: Runnable[dict, RetrieverOutput] = retriever
    else:
        retrieval_docs = (lambda x: x["input"]) | retriever

    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
        ).assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")

    return retrieval_chain
Reasoned for a couple of seconds
In create_retrieval_chain, the “combine docs” step does receive access to the retrieved documents, thanks to the line:

python
RunnablePassthrough.assign(
    context=retrieval_docs.with_config(run_name="retrieve_documents")
).assign(answer=combine_docs_chain)
Here’s how that works step by step:

retrieval_docs (the first half) returns a list of documents from the retriever.
RunnablePassthrough.assign(context=...) means:
“Take the original dictionary that came in (which probably had keys like {"input": "...", ...}),
Then add/override a new key "context" that gets its value from retrieval_docs.”
So now your dictionary looks like
python
{
  "input": "...",
  ...,
  "context": [Document(...), Document(...), ...]
}
Next step: .assign(answer=combine_docs_chain) means we call combine_docs_chain with that entire dictionary (which now has "context").
Because combine_docs_chain expects to see inputs["context"], it can read the docs from the newly filled context field.
combine_docs_chain uses or summarizes those docs, returning something like a string final answer.
Hence, the “combine docs” chain sees "context" in the same dictionary because the “retrieval docs” are assigned to the "context" key before calling combine_docs_chain. That is exactly how the second step has access to the retrieved documents.
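
The chained assign calls boil down to successive dictionary merges. Here is a minimal plain-Python sketch of that idea; assign and pipe are hypothetical helpers written for illustration, not LangChain functions, and the two lambdas stand in for the retriever and the combine-docs chain.

```python
from typing import Any, Callable, Dict


def assign(**steps: Callable[[Dict], Any]) -> Callable[[Dict], Dict]:
    """Return a step that copies the input dict and adds one key per function."""
    def run(inputs: Dict) -> Dict:
        return {**inputs, **{key: fn(inputs) for key, fn in steps.items()}}
    return run


def pipe(*fns: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Compose steps left to right, like the | operator on Runnables."""
    def run(value: Any) -> Any:
        for fn in fns:
            value = fn(value)
        return value
    return run


# Stand-ins for the retriever and the combine-docs chain
retrieval_docs = lambda x: [f"doc about {x['input']}"]
combine_docs_chain = lambda x: f"answer from {len(x['context'])} doc(s) for: {x['input']}"

retrieval_chain = pipe(
    assign(context=retrieval_docs),      # adds "context"
    assign(answer=combine_docs_chain),   # sees "context", adds "answer"
)

result = retrieval_chain({"input": "LCEL"})
```

The second assign receives the dictionary produced by the first, which is why combine_docs_chain can read x["context"].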

You said:
So this retrieves the documents to use as context:

        history_aware_retriever = self.create_history_aware_retriever(
            retriever,
            registry['contextualized_template'](),
            preprompt_filter=self.preprompt_filter(state, metadata),
        )

And this populates the context in the prompt and sends it off to the LLM for chat completion:

        question_answer_chain = self.create_stuff_documents_chain(
            registry['qa_template'](system_prompt),
            preprompt_filter=self.preprompt_filter(state, metadata)
        )

This is what I would like to do: I want to return multiple responses to the user. For one of the responses, instead of the retriever that uses MMR, I want to use one that uses plain semantic similarity, and I want to pass its results to a chat model bound (via a runnable binding) to a temperature different from the one the user set. I want the two branches to run in parallel.

Reasoned for 6 seconds
Here’s one way to do it using LangChain’s RunnableParallel, so you can produce two distinct responses in parallel – each with:

A different retrieval mode (e.g., MMR vs. similarity).
A different LLM parameter (e.g., a different temperature).
Then you return both outputs as a dictionary with two keys.

1) Build two sub-chains
Let’s say your normal chain is something like “retrieve docs, then run a ‘stuff documents’ chain with a certain prompt.” You can do two separate sub-chains, each with a different retrieval approach and/or different LLM parameters.

Sub-chain A: MMR retrieval, plus an LLM bound to its own temperature

python
def create_mmr_chain(self, system_prompt: str, mmr_temp: float = 0.0) -> Runnable:
    # 1) an MMR retriever
    mmr_retriever = self.vector_store.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "lambda_mult": 0.2}  # or whatever you want
    )

    # 2) a custom LLM binding with a different temperature
    #    (whether temperature is accepted as a call-time kwarg depends on the provider)
    mmr_chat_model = self.chat_model.bind(temperature=mmr_temp)

    # 3) the answer step, built inline so that it uses mmr_chat_model;
    #    create_stuff_documents_chain as written above always uses self.chat_model
    def format_docs(inputs: dict) -> str:
        return DEFAULT_DOCUMENT_SEPARATOR.join(
            format_document(doc, DEFAULT_DOCUMENT_PROMPT)
            for doc in inputs["context"]
        )

    answer_chain = (
        RunnablePassthrough.assign(context=format_docs)
        | registry["qa_template"](system_prompt)
        | mmr_chat_model
        | StrOutputParser()
    )

    # 4) retrieve with the query string (the retriever expects a str, not the dict),
    #    store the docs under "context", then generate the answer
    chain = RunnablePassthrough.assign(
        context=(lambda x: x["input"]) | mmr_retriever
    ).assign(answer=answer_chain)

    return chain.with_config(run_name="mmr_chain")
Sub-chain B: similarity retrieval, plus an LLM bound to a different temperature (user_temp is just the parameter name here; pass whichever override this branch should use)

python
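def create_similarity_chain(self, system_prompt: str, user_temp: float = 0.8) -> Runnable:
    # 1) similarity retriever
    sim_retriever = self.vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 4}
    )

    # 2) chat model bound to this branch's temperature
    #    (whether temperature is accepted as a call-time kwarg depends on the provider)
    sim_chat_model = self.chat_model.bind(temperature=user_temp)

    # 3) the answer step, built inline so that it uses sim_chat_model
    #    instead of the default self.chat_model
    def format_docs(inputs: dict) -> str:
        return DEFAULT_DOCUMENT_SEPARATOR.join(
            format_document(doc, DEFAULT_DOCUMENT_PROMPT)
            for doc in inputs["context"]
        )

    answer_chain = (
        RunnablePassthrough.assign(context=format_docs)
        | registry["qa_template"](system_prompt)
        | sim_chat_model
        | StrOutputParser()
    )

    # 4) retrieve with the query string, store the docs under "context", then answer
    chain = RunnablePassthrough.assign(
        context=(lambda x: x["input"]) | sim_retriever
    ).assign(answer=answer_chain)

    return chain.with_config(run_name="similarity_chain")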
In each of these sub-chains, you do the steps:

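Extract the query string from the input and pass it to the retriever inside RunnablePassthrough.assign(context=...), so that the retriever’s list of docs goes under "context".
Then you run a “stuff documents” step that uses the LLM bound for that branch.
The chain returns the input dictionary plus "context" and "answer" ({"answer": "...some text...", ...}) at the end.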
2) Combine them in parallel
Once you have these sub-chains, do something like:

python
from langchain_core.runnables import RunnableParallel

mmr_chain = self.create_mmr_chain(system_prompt, mmr_temp=0.0)
sim_chain = self.create_similarity_chain(system_prompt, user_temp=0.8)

parallel = RunnableParallel({
    "mmr_result": mmr_chain,
    "similarity_result": sim_chain,
})
When you call:

python
final_outputs = parallel.invoke({"input": user_input})
you’ll get a dictionary:

python
{
  "mmr_result": {
    "context": [...],   # if sub-chain retained that
    "answer": "Here’s the MMR-based answer"
  },
  "similarity_result": {
    "context": [...],
    "answer": "Here’s the similarity-based answer"
  }
}
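The sub-chains run concurrently: invoke() runs them on a thread pool, and ainvoke() runs them as concurrent asyncio tasks. If you do .astream(...), you get partial outputs from both chains interleaved, with each chunk keyed by the branch name.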
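
As a rough mental model of the async case, the fan-out is close to an asyncio.gather over the branches, with the results collected under their keys. The sketch below is plain Python, not LangChain: run_parallel, mmr_branch, and similarity_branch are hypothetical names, and the short sleeps stand in for retrieval and model latency.

```python
import asyncio
from typing import Awaitable, Callable, Dict

BranchFn = Callable[[Dict], Awaitable[Dict]]


async def run_parallel(steps: Dict[str, BranchFn], inputs: Dict) -> Dict:
    """Run every branch on the same input concurrently; key results by branch name."""
    keys = list(steps)
    results = await asyncio.gather(*(steps[key](inputs) for key in keys))
    return dict(zip(keys, results))


async def mmr_branch(inputs: Dict) -> Dict:
    await asyncio.sleep(0.01)  # placeholder for retrieval + model latency
    return {"answer": f"MMR answer to {inputs['input']}"}


async def similarity_branch(inputs: Dict) -> Dict:
    await asyncio.sleep(0.01)
    return {"answer": f"similarity answer to {inputs['input']}"}


final_outputs = asyncio.run(
    run_parallel(
        {"mmr_result": mmr_branch, "similarity_result": similarity_branch},
        {"input": "What is LCEL?"},
    )
)
```

Both branches receive the same input dictionary, and the total wait is roughly that of the slower branch rather than the sum of the two.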

3) Handling user input / system prompts
You might want each chain to see the same user input. That’s easy – you pass the same dictionary to parallel.invoke(...). If each chain needs a system_prompt, you can either:

Hardcode the system prompt in each sub-chain, or
Store it in the input dictionary under some key, then read it in each sub-chain or prompt.
4) If you want them both to store an AIMessage in DB
RunnableWithMessageHistory stores whatever messages it finds under output_messages_key. The parallel output above has no single answer key, so if you want to store both sub-chains’ final AI messages, you’ll need to either customize _exit_history or have your chain return a list of messages under one key and set output_messages_key to that key. Something like:

python
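from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda

def to_messages(final_outputs: dict) -> dict:
    # final_outputs = {
    #   "mmr_result": {"answer": "..."},
    #   "similarity_result": {"answer": "..."},
    # }
    return {
        "multi_answers": [
            AIMessage(content=final_outputs["mmr_result"]["answer"]),
            AIMessage(content=final_outputs["similarity_result"]["answer"]),
        ]
    }

# Wrap this chain with output_messages_key="multi_answers"
chain = parallel | RunnableLambda(to_messages)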
so _exit_history sees multiple AI messages. Or you might keep it simpler, storing only one. That’s up to you.

Putting it all together (pseudo-code)
python
from typing import AsyncIterator

async def multi_response_astream(self, user_input: str) -> AsyncIterator[dict]:
    # Suppose we have a system prompt
    system_prompt = "You are an advanced question-answering system."

    # Build each sub-chain
    mmr_chain = self.create_mmr_chain(system_prompt, mmr_temp=0.0)
    sim_chain = self.create_similarity_chain(system_prompt, user_temp=0.8)

    # Wrap them in parallel
    parallel = RunnableParallel({
        "mmr_result": mmr_chain,
        "similarity_result": sim_chain,
    })

    # Possibly also wrap it with RunnableWithMessageHistory if you want partial storage
    # or store them manually
    async for partial_output in parallel.astream({"input": user_input}):
        yield partial_output
At the end, you get two parallel sub-chains, each retrieving docs differently and using different LLM temperatures. They run concurrently, returning two final answers in a dictionary.


