In [1]:
from dotenv import load_dotenv

load_dotenv()

True

In [2]:
import psycopg2
from contextlib import contextmanager
from typing import Optional, Iterator

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint import SerializerProtocol
from langgraph.checkpoint.base import (

    SerializerProtocol,
)

class PostgresSaver(SqliteSaver):
    """A checkpoint saver that stores checkpoints in a PostgreSQL database."""

    def __init__(
        self,
        conn: psycopg2.extensions.connection,
        *,
        serde: Optional[SerializerProtocol] = None,
    ) -> None:
        super().__init__(conn=conn, serde=serde)

    @classmethod
    def from_conn_string(cls, conn_string: str) -> "PostgresSaver":
        """Create a new PostgresSaver instance from a connection string.

        Args:
            conn_string (str): The PostgreSQL connection string.

        Returns:
            PostgresSaver: A new PostgresSaver instance.
        """
        return PostgresSaver(
            conn=psycopg2.connect(conn_string)
        )

    def setup(self) -> None:
        """Set up the checkpoint database.

        This method creates the necessary tables in the PostgreSQL database if they don't
        already exist. It is called automatically when needed and should not be called
        directly by the user.
        """
        if self.is_setup:
            return

        with self.conn.cursor() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS checkpoints (
                    thread_id TEXT NOT NULL,
                    thread_ts TEXT NOT NULL,
                    parent_ts TEXT,
                    checkpoint BYTEA,
                    metadata BYTEA,
                    PRIMARY KEY (thread_id, thread_ts)
                );
                """
            )
            self.conn.commit()

        self.is_setup = True

    @contextmanager
    def cursor(self, transaction: bool = True) -> Iterator[psycopg2.extensions.cursor]:
        """Get a cursor for the PostgreSQL database.

        This method returns a cursor for the PostgreSQL database. It is used internally
        by the PostgresSaver and should not be called directly by the user.

        Args:
            transaction (bool): Whether to commit the transaction when the cursor is closed. Defaults to True.

        Yields:
            psycopg2.extensions.cursor: A cursor for the PostgreSQL database.
        """
        self.setup()
        cur = self.conn.cursor()
        try:
            yield cur
        finally:
            if transaction:
                self.conn.commit()
            cur.close()

    def put(
        self,
        config: dict,
        checkpoint: dict,
        metadata: dict,
    ) -> dict:
        """Save a checkpoint to the database.

        This method saves a checkpoint to the PostgreSQL database. The checkpoint is associated
        with the provided config and its parent config (if any).

        Args:
            config (dict): The config to associate with the checkpoint.
            checkpoint (dict): The checkpoint to save.
            metadata (Optional[dict[str, Any]]): Additional metadata to save with the checkpoint. Defaults to None.

        Returns:
            dict: The updated config containing the saved checkpoint's timestamp.
        """
        with self.lock, self.cursor() as cur:
            cur.execute(
                """
                INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint, metadata)
                VALUES (%s, %s, %s, %s, %s)
                ON CONFLICT (thread_id, thread_ts)
                DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata
                """,
                (
                    str(config["configurable"]["thread_id"]),
                    checkpoint["ts"],
                    config["configurable"].get("thread_ts"),
                    self.serde.dumps(checkpoint),
                    self.serde.dumps(metadata),
                ),
            )
        return {
            "configurable": {
                "thread_id": config["configurable"]["thread_id"],
                "thread_ts": checkpoint["ts"],
            }
        }

    def get(self, config: dict) -> Optional[dict]:
        """Retrieve the latest checkpoint for a given configuration.

        Args:
            config (dict): The configuration for which to retrieve the checkpoint.

        Returns:
            Optional[dict]: The latest checkpoint or None if no checkpoint exists.
        """
        with self.lock, self.cursor() as cur:
            cur.execute(
                """
                SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata
                FROM checkpoints
                WHERE thread_id = %s
                ORDER BY thread_ts DESC
                LIMIT 1
                """,
                (str(config["configurable"]["thread_id"]),),
            )
            row = cur.fetchone()
            if row:
                return self.serde.loads(row[3])
        return None

In [3]:
from langchain_core.messages import HumanMessage
from langgraph.graph import END, MessageGraph


def add_one(input: list[HumanMessage]):
    input[0].content = input[0].content + "a"
    return input

graph = MessageGraph()

graph.add_node("branch_a", add_one)
graph.add_edge("branch_a", "branch_b")
graph.add_edge("branch_a", "branch_c")

graph.add_node("branch_b", add_one)
graph.add_node("branch_c", add_one)

graph.add_edge("branch_b", "final_node")
graph.add_edge("branch_c", "final_node")

graph.add_node("final_node", add_one)
graph.add_edge("final_node", END)

graph.set_entry_point("branch_a")

In [4]:
conn_string = "dbname=mydatabase user=myuser password=mypassword host=localhost port=5433"
memory = PostgresSaver.from_conn_string(conn_string)

runnable = graph.compile(checkpointer=memory)

In [5]:
config = {"configurable": {"thread_id": "2"}}

runnable.invoke("a", config=config)

SyntaxError: syntax error at or near "ORDER"
LINE 1: ...nt, metadata FROM checkpoints WHERE thread_id = ? ORDER BY t...
                                                             ^
