In [None]:
from typing import List, Annotated
from langgraph.graph import StateGraph, MessagesState, add_messages, START, END
from langchain_core.messages import (
    HumanMessage, 
    AIMessage, 
    SystemMessage,
    AnyMessage,
    RemoveMessage,
    trim_messages
)
from langchain_openai import ChatOpenAI
from IPython.display import Image, display
from dotenv import load_dotenv
import os

# Load variables from a local .env file (python-dotenv) before reading them.
load_dotenv()

# Connection settings come from the environment; both are None when unset.
api_key = os.environ.get("API_KEY")
base_url = os.environ.get("OPENAI_ENDPOINT")

# Model settings shared by the examples below.
model_name = "gpt-4o-mini"
temp = 0.0

# Chat model client used by the graph nodes in this notebook.
llm = ChatOpenAI(
    model=model_name,
    temperature=temp,
    api_key=api_key,
    base_url=base_url,
)

**Single Node Workflow**

In [None]:
# NOTE(review): this cell repeats the `load_dotenv` import/call from the first
# cell and rebinds `llm` WITHOUT the `base_url` / `api_key` read there, so every
# later cell uses this client rather than the one configured above.
# Confirm which of the two configurations is intended and drop the other.
from dotenv import load_dotenv
load_dotenv()
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.0,
)

In [None]:
def llm_node(state: MessagesState):
    """Send the full message history to the LLM and return its reply.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update whose ``"messages"`` entry is the LLM reply.
    """
    history = state["messages"]
    reply = llm.invoke(history)
    return {"messages": reply}

In [None]:
# Single-node graph: START -> llm_node -> END.
workflow = StateGraph(MessagesState)
workflow.add_node("llm_node", llm_node)
workflow.add_edge(START, "llm_node")
workflow.add_edge("llm_node", END)
graph = workflow.compile()

# Render the compiled graph as a Mermaid PNG in the cell output.
display(Image(graph.get_graph().draw_mermaid_png()))

**Useful Messages List**

In [None]:
# Sample conversation reused by every example below: one system prompt
# (id "0") followed by alternating user / assistant turns (ids "1".."5"),
# ending on an unanswered user question.
messages = [
    SystemMessage(
        content="You're a FinTech specialist. You're not allowed to talk about anything else. Be concise in your answers.",
        name="System",
        id="0",
    ),
    HumanMessage(
        content="What is Pokemon?",
        name="User",
        id="1"
    ),
    AIMessage(
        content="I'm here to provide information specifically about FinTech. If you have "
                "any questions related to financial technology, such as digital payments, "
                "blockchain, cryptocurrencies, or financial services innovations, feel free "
                "to ask!",
        name="FintechAssistant",
        id="2",
    ),
    HumanMessage(
        content="What is BlockChain?",
        name="User",
        id="3"
    ),
    AIMessage(
        # Adjacent literals are concatenated verbatim, so each fragment must
        # carry its own trailing space ("records " was missing one).
        content="Blockchain is a decentralized digital ledger technology that records "
                "transactions across multiple computers in a way that ensures the security, "
                "transparency, and integrity of the data. Each transaction is grouped into "
                "a block, and these blocks are linked together in chronological order to form "
                "a chain, hence the name blockchain.",
        name="FintechAssistant",
        id="4",
    ),
    HumanMessage(
        content="What is a credit card fraud?",
        name="User",
        id="5"
    ),
]

**Filter Messages during Invocation**

In [None]:
# Baseline: run the graph on the entire conversation.
complete_output = graph.invoke({"messages": messages})

In [None]:
# Show the last message of the full-history run.
complete_output["messages"][-1].pretty_print()

In [None]:
# Token usage reported in the response metadata for the full-history run.
complete_output["messages"][-1].response_metadata["token_usage"]

In [None]:
# Filter at call time: pass only the system prompt (first message) and the
# latest user question (last message).
filtered_output = graph.invoke({"messages": [messages[0], messages[-1]]})

In [None]:
# Show the last message of the filtered run.
filtered_output["messages"][-1].pretty_print()

In [None]:
# Token usage for the filtered run, for comparison with the full-history run.
filtered_output["messages"][-1].response_metadata["token_usage"]

**Filter Messages Inside a Node**

In [None]:
class State(MessagesState):
    """Graph state extending MessagesState with a separate filtered message list."""
    # Subset of messages handled by the nodes below; updates to this key are
    # merged with the add_messages reducer.
    filtered_messages: Annotated[List[AnyMessage], add_messages]

In [None]:
def llm_node(state: State):
    """Call the LLM with only the three most recent messages.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update: the LLM reply under ``"messages"`` and the
        three-message window followed by that reply under
        ``"filtered_messages"``.
    """
    # Slicing copies, so the list stored in the state is never mutated here.
    recent_window = state["messages"][-3:]
    reply = llm.invoke(recent_window)
    return {
        "messages": reply,
        "filtered_messages": recent_window + [reply],
    }

In [None]:
# Rebuild the single-node graph on the extended State schema:
# START -> llm_node -> END.
workflow = StateGraph(State)
workflow.add_node("llm_node", llm_node)
workflow.add_edge(START, "llm_node")
workflow.add_edge("llm_node", END)
graph = workflow.compile()

# Render the compiled graph as a Mermaid PNG in the cell output.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Print the full input conversation for reference.
for m in messages:
    m.pretty_print()

In [None]:
# Run the graph on the full conversation, then print only the filtered
# messages the node recorded.
output = graph.invoke({'messages': messages})
for m in output['filtered_messages']:
    m.pretty_print()

**Remove Messages**

In [None]:
# Everything except the last three messages: the candidates for removal below.
messages[:-3]

In [None]:
# Build one RemoveMessage per message outside the last three (matched by id),
# then merge them into the conversation with add_messages. The merged result
# is only displayed as cell output; `messages` is not reassigned here.
delete_messages = [RemoveMessage(id=m.id) for m in messages[:-3]]
add_messages(messages, delete_messages)

In [None]:
# NOTE(review): identical to the State class defined in the
# "Filter Messages inside a node" section; this cell rebinds the same name.
class State(MessagesState):
    """Graph state extending MessagesState with a separate filtered message list."""
    # Messages kept after removal, plus the LLM reply; updates to this key are
    # merged with the add_messages reducer.
    filtered_messages: Annotated[List[AnyMessage], add_messages]

In [None]:
def removal_filter(state: State):
    """Drop older messages, keeping the last three and anything named "System".

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update whose ``"filtered_messages"`` entry is the
        result of applying the removals to the full history with
        ``add_messages``. The ``"messages"`` key is not updated.
    """
    history = state["messages"]

    # One RemoveMessage per message outside the last three, skipping the
    # message whose name is "System".
    removals = []
    for msg in history[:-3]:
        if msg.name == "System":
            continue
        removals.append(RemoveMessage(id=msg.id))

    return {"filtered_messages": add_messages(history, removals)}

In [None]:
def llm_node(state: State):
    """Call the LLM with the filtered messages and record its reply there.

    Args:
        state: Graph state; only ``state["filtered_messages"]`` is read.

    Returns:
        A partial state update whose ``"filtered_messages"`` entry is the
        LLM reply.
    """
    reply = llm.invoke(state["filtered_messages"])
    return {"filtered_messages": reply}

In [None]:
# Two-node graph: START -> removal_filter -> llm_node -> END.
workflow = StateGraph(State)
workflow.add_node("llm_node", llm_node)
workflow.add_node("removal_filter", removal_filter)
workflow.add_edge(START, "removal_filter")
workflow.add_edge("removal_filter", "llm_node")
workflow.add_edge("llm_node", END)
graph = workflow.compile()

# Render the compiled graph as a Mermaid PNG in the cell output.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Print the full input conversation for reference.
for m in messages:
    m.pretty_print()

In [None]:
# Run the removal graph on the full conversation, then print the filtered
# messages it produced.
output = graph.invoke({'messages': messages})
for m in output['filtered_messages']:
    m.pretty_print()

**Trim Messages**

In [None]:
# Preview trimming outside the graph: keep the most recent messages
# (strategy="last") within a 30-token budget, counting tokens with `llm`,
# keeping the system message and never splitting a message.
trim_messages(
    messages,
    max_tokens=30,
    strategy="last",
    token_counter=llm,
    allow_partial=False,
    include_system=True
)

In [None]:
class State(MessagesState):
    """Graph state extending MessagesState with a token budget and a trimmed message list."""
    # Token budget read by trim_filter; a falsy value means no trimming there.
    max_tokens: int
    # Trimmed messages plus the LLM reply; updates to this key are merged with
    # the add_messages reducer.
    filtered_messages: Annotated[List[AnyMessage], add_messages]

In [None]:
def trim_filter(state: State):
    """Trim the conversation to the token budget stored in the state.

    Args:
        state: Graph state. ``state["messages"]`` is the conversation;
            ``max_tokens`` is an optional token budget.

    Returns:
        A partial state update whose ``"filtered_messages"`` entry is the
        trimmed conversation, or the untouched conversation when
        ``max_tokens`` is missing or falsy (``None`` / ``0``).
    """
    messages = state["messages"]

    # `max_tokens` may be absent when the caller invokes the graph without
    # it; .get() treats that as "no limit" instead of raising KeyError.
    max_tokens = state.get("max_tokens")
    if not max_tokens:
        return {"filtered_messages": messages}

    # Keep the most recent messages that fit the budget, counting tokens
    # with the model, keeping the system message, never splitting a message.
    filtered_messages = trim_messages(
        messages=messages,
        max_tokens=max_tokens,
        strategy="last",
        token_counter=llm,
        include_system=True,
        allow_partial=False
    )
    return {"filtered_messages": filtered_messages}

In [None]:
def llm_node(state: State):
    """Call the LLM with the trimmed messages and record its reply there.

    Args:
        state: Graph state; only ``state["filtered_messages"]`` is read.

    Returns:
        A partial state update whose ``"filtered_messages"`` entry is the
        LLM reply.
    """
    trimmed_history = state["filtered_messages"]
    reply = llm.invoke(trimmed_history)
    return {"filtered_messages": reply}

In [None]:
# Two-node graph: START -> trim_filter -> llm_node -> END.
workflow = StateGraph(State)
workflow.add_node("llm_node", llm_node)
workflow.add_node("trim_filter", trim_filter)
workflow.add_edge(START, "trim_filter")
workflow.add_edge("trim_filter", "llm_node")
workflow.add_edge("llm_node", END)
graph = workflow.compile()

# Render the compiled graph as a Mermaid PNG in the cell output.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Print the full input conversation for reference.
for m in messages:
    m.pretty_print()

In [None]:
# Run the trimming graph on the full conversation with a 50-token budget.
output = graph.invoke(
    input={
        "max_tokens": 50,
        "messages": messages
    }
)

In [None]:
# Print the filtered messages produced by the trimming graph.
for m in output['filtered_messages']:
    m.pretty_print()

**Summary**

In [None]:
# The turns between the system prompt and the final question: the part of the
# conversation that gets summarized below.
messages[1:-1]

In [None]:
# Everything between the system prompt and the final user question.
messages_to_summarize = messages[1:-1]

# Instruction appended after those turns to request a summary.
summary_message = HumanMessage(
    name="User",
    content="Create a summary of the conversation above:",
)

# Ask the LLM to summarize the selected turns.
ai_message = llm.invoke(add_messages(messages_to_summarize, summary_message))

In [None]:
# Display the generated summary text.
ai_message.content

In [None]:
# Give the summary and the final user question fixed ids "1" and "2".
# NOTE(review): this mutates the last element of the shared `messages` list in
# place (its id was "5"), so earlier cells re-run after this one see the new id.
ai_message.id = "1"
messages[-1].id = "2"

In [None]:
# Condensed conversation: system prompt, summary of earlier turns, final question.
remaining_messages = [messages[0], ai_message, messages[-1]]

In [None]:
# Display the condensed conversation.
remaining_messages

In [None]:
# Answer the final question using the condensed conversation and append the
# reply to it.
# NOTE(review): not idempotent: re-running this cell appends another reply.
remaining_messages.append(llm.invoke(remaining_messages))

In [None]:
# Print the condensed conversation including the new reply.
for m in remaining_messages:
    m.pretty_print()

In [None]:
# Token usage reported for the reply generated from the condensed conversation.
remaining_messages[-1].response_metadata["token_usage"]