In [None]:
import hashlib
import os

from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langgraph.graph import Graph, END
from langgraph.store.postgres import PostgresStore
from psycopg import Connection
from psycopg.rows import dict_row

# ----------------- SETUP -----------------
load_dotenv()

# The DB_URI environment variable (or .env entry) overrides the local default.
DB_URI = os.getenv(
    "DB_URI", "postgresql://postgres:postgres@localhost:5432/rag?sslmode=disable"
)
embedding_dim = 1536  # Use reduced dimension

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    # Fail here with a clear message instead of on the first Gemini API call.
    raise RuntimeError("GOOGLE_API_KEY is not set; add it to the environment or .env")

embeddings = GoogleGenerativeAIEmbeddings(
    model="models/gemini-embedding-001",
    google_api_key=GOOGLE_API_KEY,
    # NOTE(review): confirm this kwarg actually reaches the embed calls. If it is
    # ignored, the model returns full-size vectors and writes into the
    # `dims=embedding_dim` index below will fail.
    embedding_kwargs={"output_dimensionality": embedding_dim},
)

# LangGraph's Postgres store expects a manually opened connection to use
# autocommit and a dict row factory (the same settings its from_conn_string
# helper applies).
conn = Connection.connect(
    DB_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
)
store = PostgresStore(conn, index={"embed": embeddings, "dims": embedding_dim, "hnsw": False})
store.setup()  # create / migrate the store's tables if needed

# NOTE(review): confirm "gemini-1.5-flash" is still served by the Gemini API.
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GOOGLE_API_KEY)

# ----------------- GRAPH LOGIC -----------------
# Every stored chunk is written under, and searched within, this namespace.
namespace = ("1", "documents")

# --- Node: Start ---
def start_node(state):
    """Prompt for the run mode and record the stripped answer in ``state["mode"]``.

    The answer is not validated here; ``mode_selector`` handles that.

    Returns:
        "mode_selector", the name of the next node.
    """
    state["mode"] = input("Choose Mode: (1) Ingest Text/URL  (2) Chat : ").strip()
    return "mode_selector"

# --- Node: Mode Selector ---
def mode_selector(state):
    """Route on the mode chosen in ``start_node``.

    Mode "2" goes straight to chat. Mode "1" first asks for the story text or
    URL (stored stripped in ``state["input_source"]``) and goes on to
    ingestion. Any other value prints an error and ends the run.
    """
    choice = state["mode"]
    if choice == "2":
        return "chat_query"
    if choice == "1":
        source = input("Enter story text or YouTube URL: ")
        state["input_source"] = source.strip()
        return "check_exists"
    print("❌ Invalid mode selected.")
    return END

# --- Node: Check if URL/Text Exists ---
def check_exists(state):
    """Decide whether the current input still needs to be ingested.

    Looks for a stored chunk whose ``"source"`` field equals the raw user
    input (story text or URL). An exact-match ``filter`` is used instead of a
    semantic ``query``: a similarity search returns the nearest stored chunk
    whenever the namespace is non-empty, so a query-based check reports every
    new input as "already embedded" after the first ingestion.

    Chunks written before values carried a ``"source"`` field never match, so
    such inputs are simply ingested again.

    Returns:
        "chat_query" if this source is already stored, else "extract_content".
    """
    user_input = state["input_source"]
    results = store.search(namespace, filter={"source": user_input}, limit=1)
    if results:
        print("✅ Already embedded, switching to chat mode...")
        return "chat_query"
    return "extract_content"


# --- Node: Extract or Transcribe Content ---
def extract_content(state):
    """Put the text to ingest into ``state["text"]``.

    Inputs starting with ``http://`` or ``https://`` are treated as URLs; a
    bare ``"http"`` prefix would also catch ordinary text such as "httpd ...".
    Anything else is used verbatim.

    Returns:
        "split_chunks".
    """
    user_input = state["input_source"]
    if user_input.startswith(("http://", "https://")):
        # Placeholder — You can add yt_dlp + Whisper transcription here
        text = "This is a transcribed text from the given URL."
    else:
        text = user_input
    state["text"] = text
    return "split_chunks"


# --- Node: Split Text ---
def split_chunks(state):
    """Split ``state["text"]`` with chunk_size=500 and chunk_overlap=50.

    Stores the resulting list of strings in ``state["chunks"]``.

    Returns:
        "embed_chunks".
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    state["chunks"] = splitter.split_text(state["text"])
    return "embed_chunks"


def make_source_id(source):
    """Return a short, stable hex id for a source string (SHA-256 prefix)."""
    return hashlib.sha256(source.encode("utf-8")).hexdigest()[:16]


# --- Node: Embed and Store ---
def embed_chunks(state):
    """Embed and persist every chunk of the current source.

    Keys are prefixed with an id derived from the source, so ingesting a
    second document no longer overwrites the first one's chunks. Fixed keys
    ``chunk_1``, ``chunk_2``, ... would be shared by every document. Each
    value also records its ``"source"``, which ``check_exists`` filters on.
    Only the ``"text"`` field is embedded.

    Returns:
        "generate_summary".
    """
    chunks = state["chunks"]
    source = state["input_source"]
    prefix = make_source_id(source)
    for i, chunk in enumerate(chunks, start=1):
        value = {"text": chunk, "source": source}
        store.put(namespace, f"{prefix}_chunk_{i}", value, index=["text"])
    print(f"✅ Stored {len(chunks)} chunks in PostgreSQL.")
    return "generate_summary"

# --- Node: Generate Summary ---
def generate_summary(state):
    """Ask the LLM to summarize ``state["text"]`` and print its reply.

    Returns:
        "chat_query", so the user can ask questions afterwards.
    """
    request = HumanMessage(content=f"Summarize this story:\n\n{state['text']}")
    response = llm.invoke([request])
    print("\n📝 Summary of the text:\n", response.content)
    return "chat_query"

# --- Node: Chat with User ---
def chat_query(state):
    """Answer user questions from the stored chunks until the user exits.

    Each question is used as a similarity query for the 3 closest chunks in
    ``namespace``. Their ``"text"`` values are joined into a context block
    that is sent to the LLM together with the question.

    The loop ends when the user types "exit" or "quit" (any case) or when
    stdin is closed. Blank lines are ignored.

    Returns:
        END, so the graph stops after the chat session.
    """
    while True:
        try:
            query = input("\n💬 Ask a question (or 'exit'): ").strip()
        except EOFError:
            # stdin closed (e.g. piped input ran out): treat it like "exit".
            query = "exit"
        if not query:
            # Skip blank input instead of embedding an empty query and
            # calling the LLM without a question.
            continue
        if query.lower() in {"exit", "quit"}:
            print("👋 Ending chat.")
            return END
        results = store.search(namespace, query=query, limit=3)
        context = "\n".join(item.value["text"] for item in results)
        prompt = f"Context:\n{context}\n\nUser Question: {query}"
        answer = llm.invoke([HumanMessage(content=prompt)])
        print("\n🤖 Answer:", answer.content)

# ----------------- BUILD GRAPH -----------------
# Every node function above mutates the ``state`` dict it receives and returns
# the *name* of the node that should run next (or END). A plain ``Graph``
# instead forwards a node's return value as the next node's input and follows
# every static edge out of a node. Wiring the functions directly would hand a
# bare string (e.g. "mode_selector") to the next node, and ``mode_selector`` /
# ``check_exists`` would fan out to both of their branches at once.
# ``_as_router`` adapts each function: it records the returned name in the
# state under ``ROUTE_KEY`` and forwards the whole state. One conditional edge
# per node then follows the recorded name.
ROUTE_KEY = "__next__"


def _as_router(fn):
    """Wrap a node function so it returns the updated state plus its chosen successor."""

    def node(state):
        # Work on a copy so each step hands a fresh dict to the next node.
        new_state = dict(state or {})
        new_state[ROUTE_KEY] = fn(new_state)
        return new_state

    node.__name__ = fn.__name__
    return node


def _next_node(state):
    """Return the successor recorded by the node that just ran."""
    return state[ROUTE_KEY]


# node name -> (node function, every node name that function can return)
NODES = {
    "start": (start_node, ["mode_selector"]),
    "mode_selector": (mode_selector, ["check_exists", "chat_query", END]),
    "check_exists": (check_exists, ["extract_content", "chat_query"]),
    "extract_content": (extract_content, ["split_chunks"]),
    "split_chunks": (split_chunks, ["embed_chunks"]),
    "embed_chunks": (embed_chunks, ["generate_summary"]),
    "generate_summary": (generate_summary, ["chat_query"]),
    "chat_query": (chat_query, [END]),
}

graph = Graph()
for node_name, (node_fn, successors) in NODES.items():
    graph.add_node(node_name, _as_router(node_fn))
    # The path map declares every target the branch may take.
    graph.add_conditional_edges(node_name, _next_node, successors)

graph.set_entry_point("start")

# Run one session with: app.invoke({})
app = graph.compile()
app

In [None]:
# Print a numbered header for each of 10 sessions.
# NOTE(review): nothing else runs inside the loop; presumably each session was
# meant to call app.invoke({}) — confirm intent.
for session_number in range(1, 11):
    print(f"\nSession {session_number}")

: 