# Extra Practice: State Memory

## Build a small graph

This is a small, simple graph you can tinker with to gain more insight into how state memory is controlled.

In [1]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment;
# the boolean returned by load_dotenv is deliberately discarded.
_ = load_dotenv()

In [2]:
from typing import TypedDict, Annotated
import operator

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver

Define a simple two-node graph with the following state:

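- `lnode`: the last node that ran
- `scratch`: a scratchpad location
- `count`: a counter that is incremented each step (its `operator.add` reducer adds each node's returned value to the running total instead of overwriting it)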

In [3]:
# Graph state schema, declared with the functional TypedDict syntax.
#   lnode   - name of the last node that ran
#   scratch - free-form scratchpad string
#   count   - step counter; the operator.add reducer in Annotated makes each
#             node's returned value get added to the stored total rather
#             than replacing it
AgentState = TypedDict(
    "AgentState",
    {
        "lnode": str,
        "scratch": str,
        "count": Annotated[int, operator.add],
    },
)

In [4]:
def node1(state: AgentState):
    """First node: log the current count and mark itself as the last node.

    Returns a partial state update. ``count=1`` is an increment, not a new
    value: the ``operator.add`` reducer on ``AgentState.count`` adds it to
    the stored total.
    """
    seen = state["count"]
    print(f"node1, count:{seen}")
    return dict(lnode="node_1", count=1)


def node2(state: AgentState):
    """Second node: log the current count and mark itself as the last node.

    Returns a partial state update. ``count=1`` is an increment, not a new
    value: the ``operator.add`` reducer on ``AgentState.count`` adds it to
    the stored total.
    """
    seen = state["count"]
    print(f"node2, count:{seen}")
    return dict(lnode="node_2", count=1)

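The graph runs Node1 → Node2 → Node1 → … in a loop. The exit check (`should_continue`) runs only after Node2, so the loop stops the first time `count` is 3 or more at that point. In the run below, the final count is therefore 4.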

In [5]:
def should_continue(state: AgentState):
    """Routing predicate evaluated after Node2.

    Returns True to loop back to Node1 and False to end the run. The check
    happens only after Node2, so the stored count can end above the limit.
    """
    limit = 3
    return state["count"] < limit

In [6]:
# Wire the graph: Node1 always flows into Node2; after Node2,
# should_continue picks the next step.
builder = StateGraph(AgentState)
builder.add_node("Node1", node1)
builder.add_node("Node2", node2)

builder.add_edge("Node1", "Node2")
# The mapping translates should_continue's boolean result into a target:
# True loops back to Node1, False finishes the run.
builder.add_conditional_edges("Node2", should_continue, {True: "Node1", False: END})
# Execution starts at Node1.
builder.set_entry_point("Node1")

In [7]:
import sqlite3

# Build the checkpointer from an explicit connection instead of
# SqliteSaver.from_conn_string(":memory:"). In newer langgraph-checkpoint-sqlite
# releases that classmethod is a context manager, so its return value can no
# longer be handed straight to compile(); the constructor form works with both
# older and newer versions. check_same_thread=False allows the connection to be
# used from threads other than the one that created it. ":memory:" means the
# checkpoints live only as long as this connection (i.e. this kernel session).
memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))
graph = builder.compile(checkpointer=memory)

In [8]:
# Checkpoints are keyed by thread_id; later cells pass this same config to
# get_state / get_state_history to read this run's saved states back.
thread = {"configurable": {"thread_id": "1"}}
graph.invoke({"count": 0, "scratch": "hi"}, config=thread)

node1, count:0
node2, count:1
node1, count:2
node2, count:3


{'lnode': 'node_2', 'scratch': 'hi', 'count': 4}

In [9]:
# Show the latest checkpoint for this thread: the state values, the `next`
# nodes to run (empty once the graph has ended), and checkpoint config/metadata.
print(graph.get_state(thread))

StateSnapshot(values={'lnode': 'node_2', 'scratch': 'hi', 'count': 4}, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3ecbc-7957-6bd6-8004-6bbe42659353'}}, metadata={'source': 'loop', 'step': 4, 'writes': {'Node2': {'lnode': 'node_2', 'count': 1}}}, created_at='2024-07-10T14:50:41.952042+00:00', parent_config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3ecbc-7955-6016-8003-54f1f64e32f7'}})


In [10]:
# Walk every saved checkpoint for this thread; as the output shows, they are
# yielded newest first (step 4) back toward the start of the run.
for state in graph.get_state_history(thread):
    print(state, "\n")

StateSnapshot(values={'lnode': 'node_2', 'scratch': 'hi', 'count': 4}, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3ecbc-7957-6bd6-8004-6bbe42659353'}}, metadata={'source': 'loop', 'step': 4, 'writes': {'Node2': {'lnode': 'node_2', 'count': 1}}}, created_at='2024-07-10T14:50:41.952042+00:00', parent_config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3ecbc-7955-6016-8003-54f1f64e32f7'}}) 

StateSnapshot(values={'lnode': 'node_1', 'scratch': 'hi', 'count': 3}, next=('Node2',), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3ecbc-7955-6016-8003-54f1f64e32f7'}}, metadata={'source': 'loop', 'step': 3, 'writes': {'Node1': {'lnode': 'node_1', 'count': 1}}}, created_at='2024-07-10T14:50:41.950917+00:00', parent_config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3ecbc-7952-6cda-8002-2465676b2268'}}) 

StateSnapshot(values={'lnode': 'node_2', 'scratch': 'hi', 'count': 2}, next=('Node1',), config={'configurable': {'thread_id': '1', 'thread_t