In [None]:
from langchain.prompts import (
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    ChatPromptTemplate,
)
from langchain_core.chat_history import (
    BaseChatMessageHistory,
    InMemoryChatMessageHistory,
)
from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_ollama.llms import OllamaLLM
from pydantic import BaseModel, Field

In [None]:
# Local Ollama model shared by the chat pipeline and, via the config specs
# further down, by the summarising histories.
llm = OllamaLLM(
    model="deepseek-r1:8b",
    # presumably disables deepseek-r1's separate "thinking" output so only the
    # final answer text is returned — confirm against langchain_ollama docs
    reasoning=False,
)

# 1. Message History (full buffer: every message is kept)

In [None]:
system_prompt = "You are a helpful assistant called Zeta."

# Prompt layout, in order: the fixed system persona, the stored conversation
# (injected under the "history" variable), then the user's new input ("query").
prompt_messages = [
    SystemMessagePromptTemplate.from_template(system_prompt),
    MessagesPlaceholder(variable_name="history"),
    HumanMessagePromptTemplate.from_template("{query}"),
]
prompt_template = ChatPromptTemplate.from_messages(prompt_messages)

# LCEL composition: the formatted prompt is piped straight into the shared LLM.
pipeline = prompt_template | llm

In [None]:
chat_map = {}


def get_chat_history(session_id: str) -> InMemoryChatMessageHistory:
    """Return the history stored for ``session_id``, creating an empty one on first use.

    Histories are kept in the module-level ``chat_map`` dict, so every call with
    the same session id sees the same history object until ``chat_map`` is reset.
    """
    existing = chat_map.get(session_id)
    if existing is not None:
        return existing
    # unseen session: register a fresh, empty history for it
    fresh = InMemoryChatMessageHistory()
    chat_map[session_id] = fresh
    return fresh

In [None]:
pipeline_with_history = RunnableWithMessageHistory(
    pipeline,
    get_session_history=get_chat_history,
    input_messages_key="query",
    history_messages_key="history",
)

# Invoke: both turns share session "id_123", so the second question can be
# answered from the history recorded for the first.
msgs = ["Hi, my name is James", "What is my name again?"]
for i, msg in enumerate(msgs):
    print(f"---\nMessage {i + 1}\n---\n")
    # Pass the session id under "configurable", the documented location for
    # RunnableWithMessageHistory's session keys, rather than relying on
    # langchain-core promoting unknown top-level config keys.
    response = pipeline_with_history.invoke(
        {"query": msg}, config={"configurable": {"session_id": "id_123"}}
    )
    print(response)

# 2. Message History + Window (keep only the last k messages)

In [None]:
class BufferWindowMessageHistory(BaseChatMessageHistory, BaseModel):
    """Chat history that keeps only the most recent ``k`` messages.

    Messages outside the window are discarded outright (no summarisation)
    every time new messages are added.
    """

    # Stored conversation, oldest first.
    messages: list[BaseMessage] = Field(default_factory=list)
    # Window size: maximum number of messages retained; values <= 0 keep none.
    k: int = Field(default=0)

    def __init__(self, k: int):
        super().__init__(k=k)
        print(f"Initializing BufferWindowMessageHistory with k={k}")

    def add_messages(self, messages: list[BaseMessage]) -> None:
        """Add messages to the history, removing any messages beyond the last `k` messages.

        Args:
            messages: new messages to append (oldest first).
        """
        self.messages.extend(messages)
        window = max(self.k, 0)
        # Guard k == 0 explicitly: `seq[-0:]` is `seq[0:]`, i.e. the whole list,
        # so a plain `[-k:]` slice would silently disable trimming. Clamping
        # also stops a negative k from dropping messages off the front instead.
        self.messages = self.messages[-window:] if window else []

    def clear(self) -> None:
        """Clear the history."""
        self.messages = []

In [None]:
chat_map = {}


def get_chat_history(session_id: str, k: int = 4) -> BufferWindowMessageHistory:
    """Fetch (or lazily create) the windowed history for ``session_id``.

    ``k`` only matters the first time a session is seen: an existing history
    is returned unchanged and keeps the window size it was created with.
    """
    print(f"get_chat_history called with session_id={session_id} and k={k}")
    if session_id in chat_map:
        return chat_map[session_id]
    # first time this session is seen: create and register a new window history
    history = BufferWindowMessageHistory(k=k)
    chat_map[session_id] = history
    return history

In [None]:
pipeline_with_history = RunnableWithMessageHistory(
    pipeline,
    get_session_history=get_chat_history,
    input_messages_key="query",
    history_messages_key="history",
    # The spec ids mirror the parameter names of get_chat_history; their values
    # are supplied through config["configurable"] at invoke time.
    history_factory_config=[
        ConfigurableFieldSpec(
            id="session_id",
            name="Session ID",
            annotation=str,
            default="id_default",
            description="The session ID to use for the chat history",
        ),
        ConfigurableFieldSpec(
            id="k",
            name="k",
            annotation=int,
            default=4,
            description="The number of messages to keep in the history",
        ),
    ],
)

# Invoke: with a window of k=2 only the latest exchange is kept, so by the
# fourth message the name given in the first one has already been dropped.
msgs = [
    "Hi, my name is James",
    "I like go to the swimming pool.",
    "I'm researching the different types of conversational memory.",
    "What is my name again?",
]
for i, msg in enumerate(msgs):
    print(f"---\nMessage {i + 1}\n---\n")
    response = pipeline_with_history.invoke(
        {"query": msg},
        config={"configurable": {"session_id": "id_k2", "k": 2}},
    )
    print(response)

# 3. Message History + Summary (a single running LLM summary)

In [None]:
class ConversationSummaryMessageHistory(BaseChatMessageHistory, BaseModel):
    """Chat history that keeps only a running, LLM-written summary.

    Every ``add_messages`` call asks ``llm`` to merge the previous summary with
    the incoming messages, then replaces the whole history with a single
    ``SystemMessage`` holding the updated summary.
    """

    # At most one SystemMessage (the current summary); empty until the first update.
    messages: list[BaseMessage] = Field(default_factory=list)
    # Model used to write the summary. NOTE: the default factory is never hit in
    # practice because __init__ always passes `llm` explicitly.
    llm: OllamaLLM = Field(default_factory=OllamaLLM)

    def __init__(self, llm: OllamaLLM):
        super().__init__(llm=llm)

    def add_messages(self, messages: list[BaseMessage]) -> None:
        """Fold ``messages`` into the running summary.

        Args:
            messages: the newly exchanged messages to summarise.
        """
        # Read the previous summary BEFORE touching self.messages. Appending the
        # new messages first (as this method used to) makes them show up in the
        # prompt twice: inside "existing summary" and again under "new messages".
        existing_summary = "\n".join(str(m.content) for m in self.messages)
        # Plain "role: content" lines instead of Python reprs of message objects.
        new_messages = "\n".join(f"{m.type}: {m.content}" for m in messages)

        # construct the summary chat messages
        summary_prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(
                    "Given the existing conversation summary and the new messages, "
                    "generate a new summary of the conversation. Ensuring to maintain "
                    "as much relevant information as possible."
                ),
                HumanMessagePromptTemplate.from_template(
                    "Existing conversation summary:\n{existing_summary}\n\n"
                    "New messages:\n{messages}"
                ),
            ]
        )
        # format the messages and invoke the LLM
        new_summary = self.llm.invoke(
            summary_prompt.format_messages(
                existing_summary=existing_summary, messages=new_messages
            )
        )
        # replace the existing history with a single system summary message
        self.messages = [SystemMessage(content=new_summary)]

    def clear(self) -> None:
        """Clear the history."""
        self.messages = []

In [None]:
chat_map = {}


def get_chat_history(
    session_id: str, llm: OllamaLLM
) -> ConversationSummaryMessageHistory:
    """Return the summary history for ``session_id``, creating it on first use.

    ``llm`` is only used when the session is new; an existing history keeps
    the model it was created with.
    """
    history = chat_map.get(session_id)
    if history is None:
        # register and remember the new history in a single step
        history = chat_map[session_id] = ConversationSummaryMessageHistory(llm=llm)
    return history

In [None]:
pipeline_with_history = RunnableWithMessageHistory(
    pipeline,
    get_session_history=get_chat_history,
    input_messages_key="query",
    history_messages_key="history",
    history_factory_config=[
        ConfigurableFieldSpec(
            id="session_id",
            annotation=str,
            name="Session ID",
            description="The session ID to use for the chat history",
            default="id_default",
        ),
        ConfigurableFieldSpec(
            id="llm",
            annotation=OllamaLLM,
            name="LLM",
            description="The LLM to use for the conversation summary",
            default=llm,
        ),
    ],
)

# Invoke: after each turn, print the stored history to show the running summary.
msgs = [
    "Hi, my name is James",
    "I have been looking at ConversationBufferMemory and ConversationBufferWindowMemory.",
    "Buffer memory just stores the entire conversation",
    "Buffer window memory stores the last k messages, dropping the rest.",
]
for i, msg in enumerate(msgs):
    print(f"---\nMessage {i + 1}\n---\n")
    # history_factory_config values go under "configurable" (the documented
    # location) instead of relying on promotion of top-level config keys.
    response = pipeline_with_history.invoke(
        {"query": msg},
        config={"configurable": {"session_id": "id_123", "llm": llm}},
    )
    print(response)
    print("\nMessages in history:")
    print(chat_map["id_123"].messages)

# 4. Message History + Window + Summary (last k messages plus a summary of older ones)

In [None]:
class ConversationSummaryBufferMessageHistory(BaseChatMessageHistory, BaseModel):
    """Chat history that keeps the last ``k`` messages verbatim plus a summary.

    Messages pushed out of the ``k``-message window are folded by ``llm`` into
    a running summary, stored as a ``SystemMessage`` at index 0 of ``messages``.
    """

    # [optional summary SystemMessage] followed by up to k most recent messages.
    messages: list[BaseMessage] = Field(default_factory=list)
    # Model used to write the summary. NOTE: the default factory is never hit in
    # practice because __init__ always passes `llm` explicitly.
    llm: OllamaLLM = Field(default_factory=OllamaLLM)
    # Window size: number of verbatim messages kept; values <= 0 keep none.
    k: int = Field(default=0)

    def __init__(self, llm: OllamaLLM, k: int):
        super().__init__(llm=llm, k=k)

    def add_messages(self, messages: list[BaseMessage]) -> None:
        """Add messages to the history, removing any messages beyond
        the last `k` messages and summarizing the messages that we
        drop.

        Args:
            messages: new messages to append (oldest first).
        """
        existing_summary: SystemMessage | None = None
        old_messages: list[BaseMessage] | None = None

        # see if we already have a summary message
        if self.messages and isinstance(self.messages[0], SystemMessage):
            print(">> Found existing summary")
            existing_summary = self.messages.pop(0)

        # add the new messages to the history
        self.messages.extend(messages)

        # Check if we have too many messages. The overflow count is derived from
        # the length instead of slicing with `[-k:]`, because `seq[-0:]` is the
        # whole list and would disable trimming for k == 0.
        window = max(self.k, 0)
        overflow = len(self.messages) - window
        if overflow > 0:
            print(
                f">> Found {len(self.messages)} messages, dropping "
                f"oldest {overflow} messages."
            )
            # the oldest `overflow` messages are the ones leaving the window...
            old_messages = self.messages[:overflow]
            # ...and only the newest `window` messages stay
            self.messages = self.messages[overflow:]

        if old_messages is None:
            print(">> No old messages to update summary with")
            # Nothing new to summarise, but the summary popped above must be
            # restored, otherwise it would be silently lost.
            if existing_summary is not None:
                self.messages.insert(0, existing_summary)
            return

        # construct the summary chat messages
        summary_prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(
                    "Given the existing conversation summary and the new messages, "
                    "generate a new summary of the conversation. Ensuring to maintain "
                    "as much relevant information as possible."
                ),
                HumanMessagePromptTemplate.from_template(
                    "Existing conversation summary:\n{existing_summary}\n\n"
                    "New messages:\n{old_messages}"
                ),
            ]
        )

        # Pass plain text rather than message objects, so the prompt contains no
        # Python reprs (nor the literal "None" before the first summary exists).
        summary_text = (
            str(existing_summary.content) if existing_summary is not None else ""
        )
        old_text = "\n".join(f"{m.type}: {m.content}" for m in old_messages)

        # format the messages and invoke the LLM
        new_summary = self.llm.invoke(
            summary_prompt.format_messages(
                existing_summary=summary_text, old_messages=old_text
            )
        )
        print(f">> New summary: {new_summary}")
        # prepend the new summary to the history
        self.messages = [SystemMessage(content=new_summary)] + self.messages

    def clear(self) -> None:
        """Clear the history."""
        self.messages = []

In [None]:
chat_map = {}


def get_chat_history(
    session_id: str, llm: OllamaLLM, k: int
) -> ConversationSummaryBufferMessageHistory:
    """Look up the summary-buffer history for ``session_id``, creating it on first use.

    ``llm`` and ``k`` are only consulted when the session is new; an existing
    history keeps the model and window size it was created with.
    """
    try:
        return chat_map[session_id]
    except KeyError:
        # unseen session: build its history and remember it for later calls
        history = ConversationSummaryBufferMessageHistory(llm=llm, k=k)
        chat_map[session_id] = history
        return history

In [None]:
pipeline_with_history = RunnableWithMessageHistory(
    pipeline,
    get_session_history=get_chat_history,
    input_messages_key="query",
    history_messages_key="history",
    history_factory_config=[
        ConfigurableFieldSpec(
            id="session_id",
            annotation=str,
            name="Session ID",
            description="The session ID to use for the chat history",
            default="id_default",
        ),
        ConfigurableFieldSpec(
            id="llm",
            annotation=OllamaLLM,
            name="LLM",
            description="The LLM to use for the conversation summary",
            default=llm,
        ),
        ConfigurableFieldSpec(
            id="k",
            annotation=int,
            name="k",
            description="The number of messages to keep in the history",
            default=4,
        ),
    ],
)

# Invoke
msgs = [
    "Hi, my name is James",
    "I have been looking at ConversationBufferMemory and ConversationBufferWindowMemory.",
    "Buffer memory just stores the entire conversation",
    "Buffer window memory stores the last k messages, dropping the rest.",
]
for i, msg in enumerate(msgs):
    print(f"---\nMessage {i + 1}\n---\n")
    # history_factory_config values (session_id, llm, k) go under "configurable",
    # the documented location, instead of the top level of `config`.
    response = pipeline_with_history.invoke(
        {"query": msg},
        config={"configurable": {"session_id": "id_123", "llm": llm, "k": 4}},
    )