In [1]:
import os
import weaviate
from weaviate.classes.init import Auth

# Best practice: store your credentials in environment variables.
# Indexing os.environ raises KeyError when a variable is missing, so a
# misconfigured environment fails here instead of at connection time.
weaviate_url = os.environ["WEAVIATE_URL"]
weaviate_api_key = os.environ["WEAVIATE_API_KEY"]

# Connect to Weaviate Cloud using API-key authentication.
# NOTE(review): no client.close() is visible in this notebook — confirm the
# connection is released once the experiments are finished.
client = weaviate.connect_to_weaviate_cloud(
    cluster_url=weaviate_url,
    auth_credentials=Auth.api_key(weaviate_api_key),
)

# Sanity check: prints the readiness flag reported by the client.
print(client.is_ready())

True


In [4]:
# Re-read the Weaviate settings. Unlike the first cell, .get() yields None for
# an unset variable instead of raising KeyError.
# NOTE(review): `client` was already built from the values read in the first
# cell, so these rebinds look redundant — confirm and drop them.
weaviate_url = os.environ.get("WEAVIATE_URL")
weaviate_api_key = os.environ.get("WEAVIATE_API_KEY")
import weaviate
from weaviate import Client
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter

from langchain_weaviate.vectorstores import WeaviateVectorStore

In [5]:
from langchain_core.embeddings import Embeddings
from langchain_huggingface import HuggingFaceEmbeddings
import os
from dotenv import load_dotenv
# Load variables from a local .env file (python-dotenv) before HF_TOKEN is read.
load_dotenv()
# Re-export HF_TOKEN only when it actually has a value. Assigning None to
# os.environ raises TypeError, which broke this cell whenever the token was
# not configured.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token is not None:
    os.environ['HF_TOKEN'] = _hf_token

def get_embeddings_model(model_name: str = "all-MiniLM-L6-v2") -> Embeddings:
    """Build the Hugging Face embedding model used by the vector store.

    Args:
        model_name: Model identifier passed to ``HuggingFaceEmbeddings``.
            Defaults to ``"all-MiniLM-L6-v2"``, the value that was previously
            hard-coded, so existing zero-argument calls behave as before.

    Returns:
        A new ``HuggingFaceEmbeddings`` instance for ``model_name``.
    """
    # NOTE(review): presumably queries must be embedded with the same model
    # that was used at indexing time — confirm before overriding the default.
    return HuggingFaceEmbeddings(model_name=model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Wrap the "pdf_index" collection as a LangChain vector store. Document text
# is read from the "text" property, and "source" and "title" are the
# attributes requested alongside it.
store = WeaviateVectorStore(
    client=client,
    index_name="pdf_index",
    text_key="text",
    embedding=get_embeddings_model(),
    attributes=["source", "title"],
)

In [11]:
# Expose the vector store through the retriever interface (default settings).
ret = store.as_retriever()


In [12]:
# Run one sample query through the retriever as a smoke test.
res = ret.invoke("what is transformer")

In [13]:
# Bare expression so the notebook renders the retrieved documents.
res

[Document(metadata={'keywords': '', 'creator': 'LaTeX with hyperref', 'file_path': 'C:\\Users\\amins\\AppData\\Local\\Temp\\ingest_74749f49-8a8b-4c2a-bb40-c934519a72f8_1y378oaf\\Attention.pdf', 'trapped': '/False', 'page_label': '3', 'creationdate': datetime.datetime(2024, 4, 10, 21, 11, 43, tzinfo=datetime.timezone.utc), 'file_name': 'Attention.pdf', 'page': 2.0, 'ptex_fullbanner': 'This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5', 'subject': '', 'directory': 'C:\\Users\\amins\\AppData\\Local\\Temp\\ingest_74749f49-8a8b-4c2a-bb40-c934519a72f8_1y378oaf', 'source': 'C:/Users/amins/AppData/Local/Temp/ingest_74749f49-8a8b-4c2a-bb40-c934519a72f8_1y378oaf/Attention.pdf', 'total_pages': 15.0, 'producer': 'pdfTeX-1.40.25', 'moddate': datetime.datetime(2024, 4, 10, 21, 11, 43, tzinfo=datetime.timezone.utc), 'title': '', 'author': ''}, page_content='Figure 1: The Transformer - model architecture.\nThe Transformer follows this overall architecture using sta

In [1]:
import os
import uuid
from datetime import datetime, timezone
from typing import cast, Optional, Any
from contextlib import contextmanager
from dataclasses import field

from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.store.postgres import PostgresStore
from pydantic import BaseModel

from backend import retrieval
from backend.configuration import Configuration, IndexConfiguration
from backend.state import InputState, State
from backend.utils import format_docs, get_message_text, load_chat_model

# Your ProdDBConfig class with user-specific configurations
class ProdDBConfig:
    """Factory helpers for the Postgres-backed LangGraph checkpointer and store.

    ``PostgresSaver.from_conn_string`` and ``PostgresStore.from_conn_string``
    are context managers: the connection is opened on ``__enter__`` and closed
    on ``__exit__``. ``checkpointer()`` and ``store()`` return those context
    managers un-entered; ``get_store_context()`` is the ready-to-use wrapper
    for store access.
    """

    # Development fallback used when POSTGRES_URI is not set. It targets a
    # local database with TLS disabled and a 10 second connect timeout.
    # Real deployments should supply credentials through the environment.
    _DEFAULT_URI = (
        "postgresql://postgres:123456@localhost:5432/langgraphrag?"
        "sslmode=disable&"
        "connect_timeout=10"
    )

    @staticmethod
    def _build_uri() -> str:
        """Return the Postgres connection URI.

        The ``POSTGRES_URI`` environment variable wins when it is set to a
        non-empty value; otherwise the local development default is returned.
        """
        return os.environ.get("POSTGRES_URI") or ProdDBConfig._DEFAULT_URI

    @staticmethod
    def checkpointer() -> Any:
        """Return the context manager that yields a ``PostgresSaver``.

        Use it as ``with ProdDBConfig.checkpointer() as saver: ...``; the
        returned object is not itself a saver.
        """
        return PostgresSaver.from_conn_string(ProdDBConfig._build_uri())

    @staticmethod
    def store() -> Any:
        """Return the context manager that yields a ``PostgresStore``.

        Use it as ``with ProdDBConfig.store() as store: ...``; the returned
        object is not itself a store.
        """
        return PostgresStore.from_conn_string(ProdDBConfig._build_uri())

    @staticmethod
    @contextmanager
    def get_store_context():
        """Yield a connected store and release it when the block exits.

        The context manager from ``store()`` is entered here, so callers
        receive the actual store object and the connection is closed even
        when the body raises.
        """
        with ProdDBConfig.store() as store:
            yield store

# Health check helper
def db_health_check():
    """Report whether a store context can be opened and closed.

    Returns:
        True when the store context was entered and exited without raising,
        False when any exception occurred (the error is printed).
    """
    healthy = False
    try:
        # Entering and leaving the context is the whole check; the yielded
        # store object itself is not used.
        with ProdDBConfig.get_store_context():
            print("Store connection successful")
        print("Database health check passed")
        healthy = True
    except Exception as exc:
        print(f"Database health check failed: {exc}")
    return healthy

# Structured-output schema for query generation: generate_query passes this
# class to with_structured_output() and reads `.query` from the result.
# NOTE(review): the class docstring may be forwarded to the model as the
# schema description — confirm before rewording it.
class SearchQuery(BaseModel):
    """Search the indexed documents for a query."""
    query: str


  from .autonotebook import tqdm as notebook_tqdm


In [None]:

async def generate_query(
    state: State, *, config: RunnableConfig
) -> dict[str, list[str]]:
    """Produce the search query for the current turn.

    On the first turn (exactly one message) the user's text is used verbatim.
    On later turns the query model rewrites the conversation into a
    ``SearchQuery`` via structured output.

    Returns:
        ``{"queries": [<query string>]}``.
    """
    conversation = state.messages

    # Parsed before branching, so an invalid config fails on every path.
    settings = Configuration.from_runnable_config(config)

    if len(conversation) == 1:
        return {"queries": [get_message_text(conversation[-1])]}

    query_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", settings.query_system_prompt),
            ("placeholder", "{messages}"),
        ]
    )
    structured_model = load_chat_model(settings.query_model).with_structured_output(
        SearchQuery
    )

    prompt_inputs = {
        "messages": state.messages,
        "queries": "\n- ".join(state.queries),
        "system_time": datetime.now(tz=timezone.utc).isoformat(),
        "user_id": settings.user_id,  # available to the system prompt template
    }
    prompt_value = await query_prompt.ainvoke(prompt_inputs, config)
    generated = cast(SearchQuery, await structured_model.ainvoke(prompt_value, config))
    return {"queries": [generated.query]}

async def retrieve(
    state: State, *, config: RunnableConfig
) -> dict[str, list[Document]]:
    """Log the latest query for the user, then retrieve documents for it.

    The query record is written to the store under the ``("queries", user_id)``
    namespace with key ``query_<user_id>_<UTC timestamp>``.

    Returns:
        ``{"retrieved_docs": <documents returned by the retriever>}``.
    """
    configuration = Configuration.from_runnable_config(config)
    user_id = configuration.user_id
    latest_query = state.queries[-1]

    # One UTC instant feeds both the record and its key, so the two agree.
    now = datetime.now(tz=timezone.utc)
    query_id = f"query_{user_id}_{now.strftime('%Y%m%d_%H%M%S_%f')}"
    query_data = {
        "query": latest_query,
        "user_id": user_id,
        "timestamp": now.isoformat(),
        "metadata": {
            "retriever_provider": configuration.retriever_provider,
            "embedding_model": configuration.embedding_model,
        },
    }

    # BaseStore.put takes (namespace, key, value); the namespace scopes the
    # record to this user.
    # NOTE(review): this is a synchronous DB call inside an async node —
    # confirm the blocking is acceptable.
    with ProdDBConfig.get_store_context() as store:
        store.put(("queries", str(user_id)), query_id, query_data)

    # Copy before adding the filter so the shared configuration is not mutated.
    search_kwargs = dict(configuration.search_kwargs or {})
    search_kwargs["metadata_filter"] = {"user_id": user_id}

    with retrieval.make_retriever(config) as retriever:
        # NOTE(review): whether the retriever honours "search_kwargs" passed
        # through the run config depends on retrieval.make_retriever — confirm
        # that per-user filtering is really applied.
        response = await retriever.ainvoke(
            latest_query,
            config={**config, "search_kwargs": search_kwargs},
        )
        return {"retrieved_docs": response}

async def respond(
    state: State, *, config: RunnableConfig
) -> dict[str, list[BaseMessage]]:
    """Answer the user with the response model and record the exchange.

    Two records are written to the store:

    * ``("conversations", user_id)`` / ``conversation_<user_id>_<UTC timestamp>``:
      a snapshot of messages, queries and retrieved documents, taken before
      the model is called.
    * ``("responses", user_id)`` / ``response_<user_id>_<UTC timestamp>``:
      the model output, linked back through ``conversation_id``.

    Returns:
        ``{"messages": [<model response>]}``.
    """
    configuration = Configuration.from_runnable_config(config)
    user_id = configuration.user_id
    user_ns = str(user_id)

    # Microsecond-resolution UTC key, matching the query/response keys, so two
    # turns within the same second do not overwrite each other.
    started_at = datetime.now(tz=timezone.utc)
    conversation_id = (
        f"conversation_{user_id}_{started_at.strftime('%Y%m%d_%H%M%S_%f')}"
    )
    conversation_data = {
        "user_id": user_id,
        "messages": [msg.dict() for msg in state.messages],
        "queries": state.queries,
        "retrieved_docs": [doc.dict() for doc in state.retrieved_docs],
        "timestamp": started_at.isoformat(),
        "configuration": {
            "retriever_provider": configuration.retriever_provider,
            "embedding_model": configuration.embedding_model,
        },
    }
    # BaseStore.put takes (namespace, key, value).
    with ProdDBConfig.get_store_context() as store:
        store.put(("conversations", user_ns), conversation_id, conversation_data)

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", configuration.response_system_prompt),
            ("placeholder", "{messages}"),
        ]
    )
    model = load_chat_model(configuration.response_model)

    retrieved_docs = format_docs(state.retrieved_docs)
    message_value = await prompt.ainvoke(
        {
            "messages": state.messages,
            "retrieved_docs": retrieved_docs,
            "system_time": datetime.now(tz=timezone.utc).isoformat(),
            "user_id": user_id,  # available to the system prompt template
        },
        config,
    )
    response = await model.ainvoke(message_value, config)

    # Store the response separately, linked to the conversation snapshot.
    answered_at = datetime.now(tz=timezone.utc)
    response_key = f"response_{user_id}_{answered_at.strftime('%Y%m%d_%H%M%S_%f')}"
    response_data = {
        "user_id": user_id,
        "query": state.queries[-1] if state.queries else "",
        "response": response.content if hasattr(response, 'content') else str(response),
        "timestamp": answered_at.isoformat(),
        "conversation_id": conversation_id,
    }
    with ProdDBConfig.get_store_context() as store:
        store.put(("responses", user_ns), response_key, response_data)

    return {"messages": [response]}

# Define a new graph
builder = StateGraph(State, input_schema=InputState, context_schema=Configuration)

# Linear pipeline: __start__ -> generate_query -> retrieve -> respond.
builder.add_node(generate_query)
builder.add_node(retrieve)
builder.add_node(respond)
builder.add_edge("__start__", "generate_query")
builder.add_edge("generate_query", "retrieve")
builder.add_edge("retrieve", "respond")

# ProdDBConfig.checkpointer() returns the context manager produced by
# PostgresSaver.from_conn_string(), not a saver. Enter it to obtain the real
# checkpointer, and keep the manager referenced so the connection stays open
# for the life of the kernel. Call _checkpointer_cm.__exit__(None, None, None)
# to release it.
_checkpointer_cm = ProdDBConfig.checkpointer()
checkpointer = _checkpointer_cm.__enter__()
# Creates/migrates the checkpoint tables on first use.
# NOTE(review): expected to be safe to repeat — confirm for the installed version.
checkpointer.setup()

graph = builder.compile(
    checkpointer=checkpointer,
    interrupt_before=[],
    interrupt_after=[],
)
graph.name = "RetrievalGraph"



TypeError: StateGraph.compile() got an unexpected keyword argument 'config_schema'

In [None]:

# Enhanced functions for user-specific operations
def load_user_conversation(thread_id: str, user_id: str, config: Optional[dict] = None):
    """Fetch the checkpointed graph state for one user's thread.

    Args:
        thread_id: Conversation thread to load.
        user_id: Owner of the thread.
        config: Optional run config; its ``"configurable"`` entries are merged
            last, so they override ``thread_id``/``user_id`` on key clashes.

    Returns:
        Whatever ``graph.get_state`` returns for the merged configuration.
    """
    configurable = dict(thread_id=thread_id, user_id=user_id)
    if config:
        overrides = config.get("configurable", {})
        configurable.update(overrides)

    run_config = {"configurable": configurable}
    return graph.get_state(run_config)

def list_user_conversations(user_id: str, limit: int = 100):
    """List stored checkpoints whose metadata belongs to ``user_id``.

    Filtering happens in the checkpointer query through the metadata
    ``filter``, so ``limit`` caps the number of *matching* checkpoints rather
    than being applied before the user filter.

    Args:
        user_id: Value to match against the ``user_id`` checkpoint metadata.
        limit: Maximum number of checkpoints to return.

    Returns:
        A list of checkpoint tuples, materialised before the connection is
        closed. A thread with several checkpoints appears several times.
    """
    # NOTE(review): assumes `user_id` is recorded in checkpoint metadata (it is
    # passed under "configurable" when the graph is invoked) — confirm for the
    # installed LangGraph version.
    with ProdDBConfig.checkpointer() as cp:
        # config=None searches across all threads; a config dict without a
        # thread_id is not a valid search scope.
        return list(cp.list(None, filter={"user_id": user_id}, limit=limit))

def get_user_conversation_history(user_id: str, limit: int = 50):
    """Return stored conversation records for ``user_id``, newest first.

    Records are read from the store's ``("conversations", user_id)`` namespace,
    so conversation writers must store their records under that namespace.

    Args:
        user_id: Owner of the conversations.
        limit: Maximum number of records to return.

    Returns:
        A list of ``{"id", "data", "timestamp"}`` dicts sorted by timestamp in
        descending order. An empty list is returned, and the error printed,
        when the store query fails.
    """
    with ProdDBConfig.get_store_context() as store:
        try:
            # NOTE(review): `limit` is applied by the store before the local
            # sort — confirm the store's own ordering if strict
            # "latest N" semantics are required.
            items = store.search(("conversations", str(user_id)), limit=limit)
        except Exception as exc:
            # Best effort, as before, but the failure is now reported instead
            # of being swallowed by a bare except.
            print(f"Could not load conversation history for {user_id}: {exc}")
            return []

        conversations = [
            {
                "id": item.key,
                "data": item.value,
                "timestamp": item.value.get("timestamp", ""),
            }
            for item in items
        ]
        conversations.sort(key=lambda entry: entry.get("timestamp", ""), reverse=True)
        return conversations[:limit]

In [None]:

# Example usage: run a health check, then two turns on the same thread.
# NOTE(review): the graph nodes defined above are `async def`, but this example
# uses the synchronous graph.invoke() — confirm whether the installed LangGraph
# version can run async-only nodes this way or whether ainvoke() is required.
if __name__ == "__main__":
    # Health check
    if db_health_check():
        print("Database is ready")
        
        # Example: Invoke the graph with user-specific configuration
        user_id = "user_12345"  # Or use UUID
        config = {
            "configurable": {
                # thread_id selects the checkpoint thread; reusing it below
                # continues the same conversation.
                "thread_id": f"{user_id}_session_1",
                "user_id": user_id,
                "retriever_provider": "weaviate",
                "embedding_model": "all-MiniLM-L6-v2"
            }
        }
        
        # First turn for this user/thread.
        result = graph.invoke(
            {"messages": [{"role": "user", "content": "Hello, how are you?"}]},
            config=config
        )
        
        print("Response:", result["messages"][-1].content)
        
        # Get user's conversation history
        history = get_user_conversation_history(user_id, limit=10)
        print(f"\nUser {user_id} has {len(history)} conversations")
        
        # Continue conversation with same user
        result2 = graph.invoke(
            {"messages": [{"role": "user", "content": "Tell me more about that"}]},
            config=config  # Same thread_id and user_id will load previous state
        )
        
        print("\nFollow-up response:", result2["messages"][-1].content)

In [None]:
"""Main entrypoint for the conversational retrieval graph with user memory.

This module defines the core structure and functionality of the conversational
retrieval graph with stateful user preferences via langgraph.json store.
Supports cross-thread memory (user prefs) + thread-scoped history.
"""

from datetime import datetime, timezone
from operator import add
from typing import Any, Annotated, TypedDict, cast, Sequence

from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
# NOTE(review): legacy pydantic-v1 shim. Its BaseModel is unused in the visible
# code and its Field is rebound by the pydantic import below — confirm and remove.
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from pydantic import BaseModel as PydanticBaseModel
# Must stay after the pydantic_v1 import: models deriving from the pydantic v2
# BaseModel need the v2 Field for defaults and descriptions to take effect.
from pydantic import Field

from backend import retrieval
from backend.configuration import Configuration
from backend.state import InputState, State  # Assume State updated below
from backend.utils import format_docs, get_message_text, load_chat_model

# User preference schema for tools
# `Field` must be the pydantic v2 `Field` (same major version as
# PydanticBaseModel). With the pydantic-v1 `Field`, a v2 model treats the
# returned object as a plain default value, so `name` stops being required
# and the descriptions are dropped.
# NOTE(review): the class docstring may be exposed in the tool argument
# schema — confirm before rewording it.
class UserInfo(PydanticBaseModel):
    """User preferences schema."""
    # Required.
    name: str = Field(..., description="User's full name")
    # Defaults to a fresh empty list per instance.
    favorite_topics: list[str] = Field(default_factory=list, description="Favorite topics")
    # Free-form string; the description suggests "brief" or "detailed".
    preferred_format: str = Field(default="detailed", description="Response style: brief/detailed")

# Memory tools - access store via config["store"] (langgraph.json injection)
@tool
async def get_user_prefs(runtime: Any) -> str:
    """Get stored preferences for current user."""
    # The docstring above is kept verbatim because @tool may use it as the
    # tool description.
    # NOTE(review): `runtime` is typed as Any — confirm how it is injected and
    # that `runtime.config` really carries a "store" entry.
    store = runtime.config.get("store")
    if not store:
        return "No store available."

    owner = runtime.config["configurable"]["user_id"]
    item = await store.aget(("users", owner), "profile")
    if not (item and item.value):
        return "No preferences stored yet."

    prefs = item.value
    name = prefs.get('name', 'Unknown')
    topics = prefs.get('favorite_topics', [])
    style = prefs.get('preferred_format', 'detailed')
    return f"Name: {name}\nTopics: {topics}\nFormat: {style}"

@tool
async def save_user_prefs(info: UserInfo, runtime: Any) -> str:
    """Save/update user preferences."""
    # The docstring above is kept verbatim because @tool may use it as the
    # tool description.
    store = runtime.config.get("store")
    if store:
        owner = runtime.config["configurable"]["user_id"]
        # Overwrites the single "profile" item in the user's namespace.
        await store.aput(("users", owner), "profile", info.model_dump())
        return "Preferences saved!"
    return "No store available."

# Structured-output schema: generate_query passes this class to
# with_structured_output() and reads `.query` from the result.
# NOTE(review): the class docstring may be forwarded to the model as the
# schema description — confirm before rewording it.
class SearchQuery(PydanticBaseModel):
    """Search the indexed documents for a query."""
    query: str

async def generate_query(
    state: State, *, config: RunnableConfig
) -> dict[str, list[str]]:
    """Produce the search query for the current turn.

    With exactly one message in the state, the user's text is the query.
    Otherwise the query model condenses the conversation into a
    ``SearchQuery`` via structured output.

    Returns:
        ``{"queries": [<query string>]}``.
    """
    history = state.messages
    if len(history) == 1:
        return {"queries": [get_message_text(history[-1])]}

    # Configuration is only needed on the model-backed path.
    settings = Configuration.from_runnable_config(config)
    query_prompt = ChatPromptTemplate.from_messages([
        ("system", settings.query_system_prompt),
        ("placeholder", "{messages}"),
    ])
    structured_model = load_chat_model(settings.query_model).with_structured_output(SearchQuery)

    prompt_inputs = {
        "messages": state.messages,
        "queries": "\n- ".join(state.queries),
        "system_time": datetime.now(tz=timezone.utc).isoformat(),
    }
    prompt_value = await query_prompt.ainvoke(prompt_inputs, config)
    generated = cast(SearchQuery, await structured_model.ainvoke(prompt_value, config))
    return {"queries": [generated.query]}

async def retrieve(
    state: State, *, config: RunnableConfig
) -> dict[str, list[Document]]:
    """Fetch documents for the most recent query in the state.

    Returns:
        ``{"retrieved_docs": <documents returned by the retriever>}``.
    """
    with retrieval.make_retriever(config) as doc_retriever:
        latest_query = state.queries[-1]
        documents = await doc_retriever.ainvoke(latest_query, config)
        return {"retrieved_docs": documents}

async def respond(
    state: State, *, config: RunnableConfig
) -> dict[str, Sequence[BaseMessage]]:
    """Answer the user, adapting the system prompt to stored preferences.

    When a store and a ``user_id`` are available, the user's ``"profile"``
    item from the ``("users", user_id)`` namespace is summarised and appended
    to the system prompt. The model is invoked with the two memory tools bound.

    Returns:
        ``{"messages": [<model response>]}``.
    """
    # NOTE(review): the original comment says langgraph.json injects the store
    # into `config` — confirm that a "store" key is really present at runtime.
    store = config.get("store")
    configuration = Configuration.from_runnable_config(config)

    # A missing user_id means "no stored profile" rather than a KeyError.
    user_id = (config.get("configurable") or {}).get("user_id")
    profile_item = None
    if store and user_id is not None:
        profile_item = await store.aget(("users", user_id), "profile")

    user_context = ""
    if profile_item and profile_item.value:
        prefs = profile_item.value
        user_context = (
            f"User preferences: Name={prefs.get('name', 'Unknown')}, "
            f"Topics={prefs.get('favorite_topics', [])}, "
            f"Format={prefs.get('preferred_format', 'detailed')}.\n"
            f"Adapt responses accordingly."
        )

    # The stored preferences are user-controlled text, so they go in as a
    # template *variable*. Interpolating them into the template source let a
    # brace in a preference (e.g. a name containing "{x}") break or alter the
    # prompt. The doubled braces emit a literal "{user_context}" placeholder.
    prompt = ChatPromptTemplate.from_messages([
        ("system", f"{configuration.response_system_prompt}\n\n{{user_context}}"),
        ("placeholder", "{messages}"),
    ])

    model = load_chat_model(configuration.response_model)
    # NOTE(review): the tools are bound, but no node in the visible graph
    # executes tool calls — confirm how returned tool calls are handled.
    model_with_tools = model.bind_tools([get_user_prefs, save_user_prefs])

    retrieved_docs = format_docs(state.retrieved_docs)
    message_value = await prompt.ainvoke({
        "messages": state.messages,
        "retrieved_docs": retrieved_docs,
        "system_time": datetime.now(tz=timezone.utc).isoformat(),
        "user_context": user_context,
    }, config)

    response = await model_with_tools.ainvoke(message_value, config)
    return {"messages": [response]}

# Updated State (add annotations if needed)
# NOTE(review): the node functions are annotated with `State` and read
# attributes such as `state.messages`, while this schema is a TypedDict —
# confirm that the nodes receive a compatible object at runtime.
class RetrievalState(TypedDict):
    """State schema handed to StateGraph when the graph is built."""
    # Conversation messages, annotated with the add_messages reducer.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Search queries, annotated with operator.add (list concatenation).
    queries: Annotated[list[str], add]
    # Retrieved documents; no reducer annotation.
    retrieved_docs: list[Document]

# Build graph
builder = StateGraph(RetrievalState, input_schema=InputState)  # Use your InputState

# Register the three pipeline stages under explicit node names.
for _node_name, _node_fn in (
    ("generate_query", generate_query),
    ("retrieve", retrieve),
    ("respond", respond),
):
    builder.add_node(_node_name, _node_fn)

# Linear flow: __start__ -> generate_query -> retrieve -> respond -> END.
_pipeline = ("__start__", "generate_query", "retrieve", "respond", END)
for _src, _dst in zip(_pipeline, _pipeline[1:]):
    builder.add_edge(_src, _dst)

# No checkpointer or store is passed here. Per the original note they are
# expected to come from the langgraph.json deployment configuration.
# NOTE(review): confirm for your deployment.
graph = builder.compile(interrupt_before=[], interrupt_after=[])
graph.name = "RetrievalGraph"

# Example usage function
async def chat_example():
    """Run two turns for the same user on two different threads.

    Both turns share ``user_id`` "alice123" but use separate ``thread_id``
    values, and each response is printed.
    """
    # NOTE(review): whether preferences from the first thread reach the second
    # depends on a store being available to the graph — confirm.
    user_id = "alice123"
    turns = (
        ("chat1", "Hi I'm Alice, love hiking/tech. What's RAG?"),
        ("chat2", "More on RAG, make it brief"),
    )
    for turn_no, (thread_id, text) in enumerate(turns, start=1):
        run_config = {"configurable": {"thread_id": thread_id, "user_id": user_id}}
        result = await graph.ainvoke(
            {"messages": [HumanMessage(content=text)]},
            run_config,
        )
        print(f"Response {turn_no}:", result["messages"][-1].content)