In [3]:
!pip uninstall -y pinecone-client pinecone
!pip install --no-cache-dir -U pinecone langchain-pinecone


Found existing installation: pinecone-client 6.0.0
Uninstalling pinecone-client-6.0.0:
  Successfully uninstalled pinecone-client-6.0.0
Found existing installation: pinecone 7.3.0
Uninstalling pinecone-7.3.0:
  Successfully uninstalled pinecone-7.3.0
Collecting pinecone
  Downloading pinecone-7.3.0-py3-none-any.whl.metadata (9.5 kB)
Downloading pinecone-7.3.0-py3-none-any.whl (587 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m587.6/587.6 kB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pinecone
Successfully installed pinecone-7.3.0


In [4]:
pip install -U langchain langchain-core langchain-community langchain-text-splitters langchain-huggingface


Collecting langchain
  Using cached langchain-0.3.27-py3-none-any.whl.metadata (7.8 kB)
Collecting langchain-community
  Using cached langchain_community-0.3.27-py3-none-any.whl.metadata (2.9 kB)
Collecting langchain-huggingface
  Using cached langchain_huggingface-0.3.1-py3-none-any.whl.metadata (996 bytes)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Using cached pydantic_settings-2.10.1-py3-none-any.whl.metadata (3.4 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain-community)
  Using cached httpx_sse-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Using cached langchain-0.3.27-py3-none-any.whl (1.0 MB)
Using cached langchain_community-0.3.27-py3-none-any.whl (2.5 MB)
Using cached httpx_sse-0.4.1-py3-none-any.whl (8.1 kB)
Using cached pydantic_settings-2.10.1-py3-none-any.whl (45 kB)
Using cached langchain_huggingface-0.3.1-py3-none-any.whl (27 kB)
Installing collected packages: httpx-sse, pydantic-settings, langchain-huggingface, langchain, langchain-commun

In [2]:
import pinecone, sys
# Fall back to "unknown" when the installed package exposes no __version__.
_pinecone_version = getattr(pinecone, "__version__", "unknown")
print("pinecone version:", _pinecone_version)


pinecone version: 7.3.0


In [3]:
!pip install --upgrade --no-cache-dir \
  "huggingface_hub>=0.25.0" \
  "langchain-community>=0.2.10" \
  "langchain-core>=0.2.38"




In [2]:
import huggingface_hub, langchain_community
# Report the installed versions of the two packages upgraded above.
for _label, _module in (("huggingface_hub", huggingface_hub), ("langchain_community", langchain_community)):
    print(f"{_label}:", _module.__version__)
# Expect: huggingface_hub >= 0.25.x, langchain_community >= 0.2.10


huggingface_hub: 0.34.4
langchain_community: 0.3.27


In [1]:
import os
from dotenv import load_dotenv

load_dotenv()

# Accept HUGGINGFACEHUB_API_KEY as an alias when HUGGINGFACEHUB_API_TOKEN is not set.
if os.getenv("HUGGINGFACEHUB_API_TOKEN") is None:
    _hf_alias = os.getenv("HUGGINGFACEHUB_API_KEY")
    if _hf_alias:
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_alias

pinecone_key, hf_token = os.getenv("PINECONE_API_KEY"), os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Only report presence, never the secret values themselves.
for _label, _secret in (("Pinecone key loaded:", pinecone_key), ("HF token loaded:", hf_token)):
    print(_label, bool(_secret))
assert pinecone_key and hf_token, "Missing PINECONE_API_KEY or HF token"


Pinecone key loaded: True
HF token loaded: True


In [2]:
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ---- Load the raw corpus ----
loader = TextLoader("./horoscope.txt")
documents = loader.load()

# ---- Chunk it: chunk_size=1000, 4 units of overlap between neighbouring chunks ----
text_splitter = CharacterTextSplitter(chunk_overlap=4, chunk_size=1000)
docs = text_splitter.split_documents(documents)
n_source_docs, n_chunks = len(documents), len(docs)
print(f"Loaded {n_source_docs} doc(s), split into {n_chunks} chunks.")


Loaded 1 doc(s), split into 7 chunks.


In [4]:
# Embeddings: one probe query tells us the vector width the Pinecone index must use.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
probe_vector = embeddings.embed_query("ping")
dim = len(probe_vector)
print("Embedding dim:", dim)  # should print 768 for this model


Embedding dim: 768


In [5]:
# Pinecone: make sure the serverless index exists, then attach a LangChain vector store to it.
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))

INDEX_NAME = "langchain-demo"     # lowercase / digits / dashes only
CLOUD, REGION = "aws", "us-east-1"  # free tier-friendly region

existing_index_names = [entry["name"] for entry in pc.list_indexes()]
if INDEX_NAME not in existing_index_names:
    pc.create_index(
        name=INDEX_NAME,
        dimension=dim,
        metric="cosine",
        spec=ServerlessSpec(cloud=CLOUD, region=REGION),
    )

# Upsert the chunks only into an empty index; otherwise reuse the stored vectors.
idx = pc.Index(INDEX_NAME)
stats = idx.describe_index_stats()
total = stats.get("total_vector_count", 0)

if total != 0:
    vectorstore = PineconeVectorStore(index_name=INDEX_NAME, embedding=embeddings)
    print(f"Connected to existing index with ~{total} vectors.")
else:
    vectorstore = PineconeVectorStore.from_documents(docs, embedding=embeddings, index_name=INDEX_NAME)
    print("Upserted docs to Pinecone.")

print("Vectorstore ready ✅")


Connected to existing index with ~7 vectors.
Vectorstore ready ✅


In [6]:
# Quick retrieval check: show the first 200 characters of the top-3 hits.
hits = vectorstore.similarity_search("lucky number for leo", k=3)
for rank, hit in enumerate(hits, start=1):
    snippet = hit.page_content[:200]
    print(f"\n[{rank}] {snippet}...")



[1] When hardworking Saturn shifts into ambitious, competitive Aries at the end of March, you may feel renewed confidence in your career. Embracing leadership roles can boost your morale and help you rega...

[2] Love takes you to new heights this year, Sag. You’re all about seeking truth and revealing it. As Venus, the goddess of love, spends most of January in the dreamy, creative sign of Pisces, you’re unde...

[3] Sagittarius Horoscope of 2025

As the fiery Archer, your outgoing and adventurous nature is powered by Jupiter, your optimistic planet. This year, Jupiter joins forces with Gemini and Cancer in your s...


In [7]:
# Retriever returning the 3 nearest chunks per query.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))

def format_docs(docs):
    """Concatenate the ``page_content`` of each document, separated by a blank line."""
    texts = [doc.page_content for doc in docs]
    return "\n\n".join(texts)


print("Retriever ready ✅")


Retriever ready ✅


In [8]:
# Remote Mixtral endpoint wrapped as a chat model.
_chat_endpoint = HuggingFaceEndpoint(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="conversational",  # NOTE(review): flagged "important" originally; later cells use "text-generation" — confirm
    temperature=0.7,
    top_k=50,
    max_new_tokens=256,
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)
chat_llm = ChatHuggingFace(llm=_chat_endpoint)

print("Chat LLM ready ✅")


Chat LLM ready ✅


In [9]:
# Smoke test: print the message text, or the raw reply if it has no `.content`.
reply = chat_llm.invoke("Answer briefly: What's the capital of France?")
print(reply.content if hasattr(reply, "content") else reply)


 The capital of France is Paris. Known for its iconic landmarks like the Eiffel Tower, Louvre Museum, Notre-Dame Cathedral, and sophisticated cafes, Paris is one of the most popular tourist destinations in the world. It's renowned for its art, fashion, gastronomy, and culture, making it a city that truly has something for everyone.


In [10]:
# System persona + human turn carrying the retrieved context and the question.
_SYSTEM_INSTRUCTIONS = (
    "You are a fortune teller. Use the provided context to answer. "
    "If you don't know, say you don't know. Keep the answer within 2 sentences and concise."
)
_HUMAN_TEMPLATE = "Context:\n{context}\n\nQuestion: {question}"
prompt = ChatPromptTemplate.from_messages([("system", _SYSTEM_INSTRUCTIONS), ("human", _HUMAN_TEMPLATE)])


In [11]:
def _message_text(msg):
    """Return a chat message's ``.content``, or ``str(msg)`` when there is none."""
    return getattr(msg, "content", str(msg))


# retrieve -> format -> prompt -> chat -> plain text
_rag_inputs = {
    "context": retriever | RunnableLambda(format_docs),
    "question": RunnablePassthrough(),
}
rag_chain = _rag_inputs | prompt | chat_llm | RunnableLambda(_message_text)

print("RAG chain ready ✅")


RAG chain ready ✅


In [12]:
leo_answer = rag_chain.invoke("What does this week look like for Leo?")
print(leo_answer)


 I don't have real-time information or personal data, so I can't provide specific predictions for this week for Leo. However, based on the provided context, there is no Leo-specific information available, so it would be best for Leo to follow the general guidance: stay aware of financial shifts, consider a long-term plan, and consult a professional if needed. No career or love horoscope was given for Leo this week.


In [13]:
# Typo fix ("iin" -> "in"): the question text is embedded for retrieval and sent to the LLM.
year_answer = rag_chain.invoke("How is my life going to be in 2025?")
print(year_answer)

 In 2025, you'll experience growth and transformation in your relationships due to Jupiter's influence, but be mindful of Mercury retrograde periods disrupting your travel plans. Financially, stay aware of global events and innovations, considering a long-term financial plan. Your love life may see deeper connections and fresh perspectives towards the end of the year.


In [14]:
# main.py
import os
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

from langchain_huggingface import (
    HuggingFaceEmbeddings,
    ChatHuggingFace,
    HuggingFaceEndpoint,
)
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser


class ChatBot:
    """Retrieval-augmented "fortune teller" over a local text corpus stored in Pinecone.

    On construction it:
      1. loads credentials via ``load_dotenv()`` (accepting ``HUGGINGFACEHUB_API_KEY`` as an
         alias for ``HUGGINGFACEHUB_API_TOKEN``),
      2. loads ``data_path`` and splits it into chunks (kept in ``self.docs``),
      3. creates the Pinecone index if it is missing and upserts the chunks only when the
         index holds no vectors,
      4. builds ``self.retriever`` (top-3) and ``self.rag_chain``
         (retriever -> prompt -> Mixtral chat model -> ``str``).

    Args:
        index_name: Pinecone index name (created when missing).
        cloud: Cloud for the ``ServerlessSpec`` used when creating the index.
        region: Region for the ``ServerlessSpec`` used when creating the index.
        data_path: Plain-text corpus to index. Defaults to ``./horoscope.txt``.

    Raises:
        RuntimeError: If ``PINECONE_API_KEY`` or the Hugging Face token is missing.
        FileNotFoundError: If ``data_path`` does not exist.
    """

    def __init__(
        self,
        index_name: str = "langchain-demo",
        cloud: str = "aws",
        region: str = "us-east-1",
        data_path: str = "./horoscope.txt",
    ):
        load_dotenv()

        # HF env compatibility: accept HUGGINGFACEHUB_API_KEY as an alias.
        if os.getenv("HUGGINGFACEHUB_API_TOKEN") is None and os.getenv("HUGGINGFACEHUB_API_KEY"):
            os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_KEY")

        pinecone_key = os.getenv("PINECONE_API_KEY")
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not pinecone_key or not hf_token:
            raise RuntimeError("Missing PINECONE_API_KEY or HUGGINGFACEHUB_API_TOKEN in environment.")

        # ------- Load & split docs -------
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Couldn't find '{data_path}'. Make sure the file exists.")

        documents = TextLoader(data_path).load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=4)
        self.docs = text_splitter.split_documents(documents)

        # ------- Embeddings -------
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        # The index dimension must equal the embedding width (expected 768 for this model).
        dim = len(self.embeddings.embed_query("ping"))

        # ------- Pinecone setup -------
        pc = Pinecone(api_key=pinecone_key)
        self.index_name = index_name

        # Handle both dict and object forms from pc.list_indexes() across client versions
        def _idx_name(x):
            return x.name if hasattr(x, "name") else (x.get("name") if isinstance(x, dict) else None)

        existing = {_idx_name(i) for i in pc.list_indexes()}
        if self.index_name not in existing:
            pc.create_index(
                name=self.index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(cloud=cloud, region=region),
            )

        stats = pc.Index(self.index_name).describe_index_stats()

        # Prefer the server-reported total; fall back to summing per-namespace counts.
        # Summing alone yields 0 when "namespaces" is absent, which would upsert every chunk
        # again (presumably as duplicates) on each start-up.
        total = stats.get("total_vector_count")
        if total is None:
            namespaces = stats.get("namespaces", {}) or {}
            total = sum(ns.get("vector_count", 0) for ns in namespaces.values())

        # ------- Vector store -------
        # Upsert only into an empty index; otherwise reuse the stored vectors.
        if total == 0:
            self.vectorstore = PineconeVectorStore.from_documents(
                self.docs, embedding=self.embeddings, index_name=self.index_name
            )
        else:
            self.vectorstore = PineconeVectorStore(index_name=self.index_name, embedding=self.embeddings)

        # (The former "sanity check" ran a remote similarity search, discarded the result and
        # swallowed every exception; it was removed as pure overhead.)
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

        # ------- LLM -------
        self.chat_llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                task="text-generation",
                temperature=0.7,
                top_k=50,
                max_new_tokens=256,
                do_sample=True,
                return_full_text=False,
                huggingfacehub_api_token=hf_token,
            )
        )

        # ------- Prompt -------
        prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a fortune teller. Use the provided context to answer. "
             "If you don't know, say you don't know. Keep the answer within 2 sentences and concise."),
            ("human", "Context:\n{context}\n\nQuestion: {question}")
        ])

        # Keep as a nested function to avoid capturing self in RunnableLambda
        def format_docs(docs):
            return "\n\n".join(d.page_content for d in docs)

        # ------- RAG chain -------
        self.rag_chain = (
            {
                "context": self.retriever | RunnableLambda(format_docs),
                "question": RunnablePassthrough(),
            }
            | prompt
            | self.chat_llm
            | StrOutputParser()
        )


if __name__ == "__main__":
    # Build the bot first (loads data, connects to Pinecone), then answer a single question.
    bot = ChatBot()
    print(bot.rag_chain.invoke(input("Ask me anything: ")))


 In 2


In [15]:
import os, time, logging
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.retrievers import ContextualCompressionRetriever

assert "vectorstore" in globals(), "Run earlier cells to create `vectorstore` first."
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)

# 1) Query-rewriter LLM: greedy decoding (do_sample=False) and a short generation budget.
#    The token is passed explicitly, as in the other LLM cells, instead of relying on the
#    endpoint's own environment lookup.
rewriter_llm = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        task="text-generation",
        temperature=0.2,
        do_sample=False,
        top_k=None,
        max_new_tokens=96,
        return_full_text=False,
        huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    )
)

# 2) Wide first-stage recall (k=10), expanded with LLM-generated query variants.
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
mqr = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=rewriter_llm)

# 3) Cross-encoder reranker keeping the top 5. CrossEncoderReranker takes the encoder via `model=`.
print("Loading reranker (BAAI/bge-reranker-base)…")
cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
reranker = CrossEncoderReranker(model=cross_encoder, top_n=5)

# 4) Final advanced retriever
advanced_retriever = ContextualCompressionRetriever(
    base_compressor=reranker,
    base_retriever=mqr,
)
print("Advanced retriever ready (multi-query + rerank + compression) ✅")

# 5) Smoke test. `.invoke` replaces the deprecated `get_relevant_documents`, and a distinct
#    name keeps the chunk list `docs` (used by the upsert cell) from being overwritten.
q = "What can Gemini expect this week?"
smoke_docs = advanced_retriever.invoke(q)
print("Retrieved:", len(smoke_docs), "docs")
print(smoke_docs[0].page_content[:200] + "…" if smoke_docs else "No docs")



Loading reranker (BAAI/bge-reranker-base)…
Advanced retriever ready (multi-query + rerank + compression) ✅


  docs = advanced_retriever.get_relevant_documents(q)


Retrieved: 5 docs
Sagittarius Horoscope of 2025

As the fiery Archer, your outgoing and adventurous nature is powered by Jupiter, your optimistic planet. This year, Jupiter joins forces with Gemini and Cancer in your s…


In [17]:
# Helpers: context formatting + citations
from typing import List
from langchain_core.documents import Document  # <-- add this import

def format_context(docs: List[Document]) -> str:
    """Join the non-empty ``page_content`` of each document with blank lines.

    Items lacking a ``page_content`` attribute, or whose content is empty, are skipped.
    """
    texts = (getattr(doc, "page_content", "") for doc in docs)
    return "\n\n".join(text for text in texts if text)

def extract_citations(docs: List[Document]) -> List[str]:
    """Return one ``"[n] <source>"`` label per document, numbered from 1.

    The source comes from ``metadata["source"]``, then ``metadata["file"]``, and defaults
    to ``"horoscope.txt"`` when neither is present (or metadata is missing).
    """
    def _source_of(doc) -> str:
        metadata = getattr(doc, "metadata", {}) or {}
        return metadata.get("source") or metadata.get("file") or "horoscope.txt"

    return [f"[{n}] {_source_of(doc)}" for n, doc in enumerate(docs, start=1)]


In [18]:
# Tools (example): lucky number + now
from langchain_core.tools import tool
import hashlib, datetime

@tool("lucky_number")
def lucky_number(name_or_sign: str) -> str:
    """
    Deterministic 'lucky number' (1-9) from a name or zodiac sign.
    """
    # Normalise (trim + lowercase) so "Gemini " and "gemini" give the same number.
    # md5 serves only as a stable hash here, not for security.
    normalized = name_or_sign.strip().lower()
    digest = hashlib.md5(normalized.encode("utf-8")).hexdigest()
    return f"Lucky number for '{name_or_sign}': {1 + int(digest, 16) % 9}"

@tool("now")
def now(_: str = "") -> str:
    """
    Current date/time (ISO format).
    """
    # Local wall-clock time (naive datetime), truncated to whole seconds; the argument is unused.
    current = datetime.datetime.now()
    return current.isoformat(timespec="seconds")

In [19]:
# LangGraph pipeline: route → tools/RAG → grade → generate (fixed imports)
from typing import TypedDict, List, Literal, Optional
from langgraph.graph import StateGraph, START, END
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document

# --- sanity checks: earlier cells must have defined these names ---
_required_globals = (
    (("advanced_retriever",), "Build advanced_retriever first"),
    (("chat_llm",), "Build chat_llm first"),
    (("lucky_number", "now"), "Define tools first"),
    (("extract_citations", "format_context"), "Add helper funcs first"),
)
for _names, _message in _required_globals:
    assert all(name in globals() for name in _names), _message

# (Optional) sign extractor to make tool output cleaner
SIGNS = [
    "aries","taurus","gemini","cancer","leo","virgo",
    "libra","scorpio","sagittarius","capricorn","aquarius","pisces"
]
def _extract_sign_or_name(q: str) -> str:
    t = (q or "").lower()
    for s in SIGNS:
        if s in t:
            return s.title()
    return q.strip()

# State dict threaded through every LangGraph node below.
# total=False: every key is optional, so a state may lack any of them.
class RAGState(TypedDict, total=False):
    session_id: str               # supplied by callers such as answer_with_graph; the nodes here don't read it
    question: str                 # user question (also the memory wrapper's input_messages_key)
    history: List                 # for memory wrapper (injected under history_messages_key="history")
    route: Literal["TOOLS", "RAG"]  # set by route_node
    docs: List[Document]          # retrieved chunks ([] on the tools path)
    citations: List[str]          # "[n] source" labels built from docs
    tool_result: Optional[str]    # tool output; when set, generate_node skips the LLM
    grounded: bool                # True when retrieval returned at least one doc
    answer: str                   # final answer text

import re

TOOL_KEYWORDS = ("lucky number", "lucky", "today", "date", "time", "now")
# Whole-word match. A plain substring test sent "What do you know about Leo?" ("know"
# contains "now") or "any updates for Virgo?" ("update" contains "date") to the tools path.
_TOOL_PATTERN = re.compile(r"\b(?:" + "|".join(map(re.escape, TOOL_KEYWORDS)) + r")\b")


def route_node(state: RAGState) -> RAGState:
    """Set ``route`` to "TOOLS" when the question contains a tool keyword as a whole word, else "RAG"."""
    q = (state.get("question") or "").lower()
    return {**state, "route": "TOOLS" if _TOOL_PATTERN.search(q) else "RAG"}

def tools_node(state: RAGState) -> RAGState:
    """Answer straight from a tool and finalize the state (no retrieval, no LLM).

    Questions mentioning "lucky" get ``lucky_number`` for the detected sign or name; any
    other question routed here gets the current time from ``now``.
    """
    question = state["question"]
    wants_lucky = "lucky" in question.lower()
    result = (
        lucky_number.invoke(_extract_sign_or_name(question))  # e.g. "Lucky number for 'Gemini': 8"
        if wants_lucky
        else now.invoke("")
    )
    return {
        **state,
        "tool_result": result,
        "answer": result,
        "docs": [],
        "citations": [],
        "grounded": True,
    }

def retrieve_node(state: RAGState) -> RAGState:
    """Fetch reranked chunks for the question and attach their citation labels."""
    found = advanced_retriever.invoke(state["question"])
    labels = extract_citations(found)
    return {**state, "docs": found, "citations": labels}

_NO_CONTEXT_ANSWER = "I don't know yet. Please add relevant content to the corpus."


def grade_node(state: RAGState) -> RAGState:
    """Mark the state as grounded iff retrieval returned documents.

    When nothing was retrieved the graph ends right after this node (generate is skipped),
    so a fallback ``answer`` is filled in here. Without it, readers of ``state["answer"]``
    raise KeyError.
    """
    grounded = bool(state.get("docs"))
    updated = {**state, "grounded": grounded}
    if not grounded:
        updated.setdefault("answer", _NO_CONTEXT_ANSWER)
    return updated

# Tight prompt: only cite provided citations; if none, cite nothing.
_RAG_SYSTEM = (
    "You are a concise fortune teller. Use ONLY the provided context. "
    "Rules: (1) If context is empty or irrelevant, say you don't know. "
    "(2) Cite ONLY items listed under 'Citations:'. "
    "(3) If 'Citations:' is empty, do not cite anything."
)
_RAG_HUMAN = "Context:\n{context}\n\nCitations:\n{citations}\n\nQuestion: {question}"
rag_prompt = ChatPromptTemplate.from_messages(
    [("system", _RAG_SYSTEM), MessagesPlaceholder("history"), ("human", _RAG_HUMAN)]
)
generator = rag_prompt | chat_llm | StrOutputParser()

def generate_node(state: RAGState) -> RAGState:
    """Produce the answer with the LLM from retrieved context, unless a tool already answered."""
    if state.get("tool_result"):
        return state  # tools path already set "answer"

    has_docs = state.get("docs")
    payload = {
        "history": state.get("history", []),
        "context": format_context(state.get("docs", [])) if has_docs else "",
        "citations": " ".join(state.get("citations", [])),
        "question": state["question"],
    }
    return {**state, "answer": generator.invoke(payload)}

def _next_after_route(s):
    """Send tool-style questions to the tools node, everything else to retrieval."""
    return "tools" if s["route"] == "TOOLS" else "retrieve"


def _next_after_grade(s):
    """Generate only when retrieval found documents; otherwise stop."""
    return "generate" if s["grounded"] else END


graph = StateGraph(RAGState)
for _node_name, _node_fn in (
    ("route", route_node),
    ("tools", tools_node),
    ("retrieve", retrieve_node),
    ("grade", grade_node),
    ("generate", generate_node),
):
    graph.add_node(_node_name, _node_fn)

graph.add_edge(START, "route")
graph.add_conditional_edges("route", _next_after_route)
graph.add_edge("tools", END)                  # tools short-circuit to END
graph.add_edge("retrieve", "grade")
graph.add_conditional_edges("grade", _next_after_grade)
graph.add_edge("generate", END)

rag_app = graph.compile()
print("LangGraph app compiled ✅")


LangGraph app compiled ✅


In [20]:
# Make the graph return only the final answer string. `.get` with a fallback prevents a
# KeyError when the graph stops at "grade" (nothing retrieved) without setting "answer".
rag_app_answer = rag_app | (
    lambda state: state.get("answer", "I don't know yet. Please add relevant content to the corpus.")
)

# Memory wrapper
try:
    from langchain_community.chat_message_histories import ChatMessageHistory as InMemoryChatMessageHistory
except ImportError:
    from langchain_community.chat_message_histories import InMemoryChatMessageHistory

from langchain_core.runnables.history import RunnableWithMessageHistory

# session_id -> chat history, kept in process memory for the life of the kernel.
_session_store = {}


def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    """Return the history for ``session_id``, creating an empty one on first use."""
    history = _session_store.get(session_id)
    if history is None:
        history = _session_store[session_id] = InMemoryChatMessageHistory()
    return history


rag_app_with_memory = RunnableWithMessageHistory(
    rag_app_answer,                  # string-output runnable
    get_session_history,
    input_messages_key="question",   # user message is read from input["question"]
    history_messages_key="history",  # must match MessagesPlaceholder("history")
)
print("Memory wrapper ready ✅")

Memory wrapper ready ✅


In [21]:
# Try it (RAG and Tools paths); both calls share one session, so the second sees the first in history.
cfg = {"configurable": {"session_id": "demo-user-1"}}


def _ask(question):
    """Invoke the memory-wrapped graph for the demo session."""
    return rag_app_with_memory.invoke({"question": question}, config=cfg)


print("RAG path (uses your corpus):")
ans1 = _ask("What can Gemini expect this week?")
print(ans1)   # plain string

print("\nTools path (short-circuits to tool output):")
ans2 = _ask("What's my lucky number for Gemini?")
print(ans2)   # plain string, no citations

RAG path (uses your corpus):
 I don't have information specific to Gemini for this week, as the context provided refers to Sagittarius Horoscope of 2025, Sagittarius Career Horoscope, and Sagittarius Love Horoscope, without any weekly updates.

Tools path (short-circuits to tool output):
Lucky number for 'Gemini': 3


In [22]:
# Helper for Streamlit to call the graph.
# `rag_app_with_memory` ends in a lambda that yields only the answer *string*, so calling
# `.get(...)` on its result raised AttributeError. This variant wraps `rag_app` itself,
# so the full state (answer + citations) comes back. `output_messages_key="answer"` tells
# the history wrapper which field to record as the AI turn. It uses the same session store.
rag_app_state_with_memory = RunnableWithMessageHistory(
    rag_app,
    get_session_history,
    input_messages_key="question",
    history_messages_key="history",
    output_messages_key="answer",
)


def answer_with_graph(user_text: str, session_id: str = "web-user-1") -> str:
    """Answer ``user_text`` through the graph with per-session memory.

    Args:
        user_text: The user's question.
        session_id: Key selecting the chat history to read and extend.

    Returns:
        The answer text, followed by ``"\\n\\nSources: [1] ... [2] ..."`` when the graph
        produced citations.
    """
    cfg = {"configurable": {"session_id": session_id}}
    state = {"session_id": session_id, "question": user_text}
    out = rag_app_state_with_memory.invoke(state, config=cfg)
    ans = out.get("answer", "")
    cites = out.get("citations", [])
    if cites:
        ans = f"{ans}\n\nSources: " + " ".join(cites)
    return ans


print("answer_with_graph() ready ✅")

answer_with_graph() ready ✅


In [1]:
# main.py
import os
from typing import List, Optional, Literal, TypedDict

from dotenv import load_dotenv

# Pinecone + Vector store
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

# Embeddings / LLM
from langchain_huggingface import (
    HuggingFaceEmbeddings,
    ChatHuggingFace,
    HuggingFaceEndpoint,
)

# Data loading / splitting
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# LangChain core
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda

# Advanced retrieval
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.retrievers import ContextualCompressionRetriever

# Tools
from langchain_core.tools import tool
import hashlib, datetime

# LangGraph
from langgraph.graph import StateGraph, START, END

# Memory (version-agnostic import)
try:
    # Newer LangChain
    from langchain_community.chat_message_histories import ChatMessageHistory as InMemoryChatMessageHistory
except ImportError:
    # Older LangChain
    from langchain_community.chat_message_histories import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory


class ChatBot:
    def __init__(
        self,
        index_name: str = "langchain-demo",
        cloud: str = "aws",
        region: str = "us-east-1",
    ):
        load_dotenv()

        # HF env compatibility
        if os.getenv("HUGGINGFACEHUB_API_TOKEN") is None and os.getenv("HUGGINGFACEHUB_API_KEY"):
            os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_KEY")

        pinecone_key = os.getenv("PINECONE_API_KEY")
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not pinecone_key or not hf_token:
            raise RuntimeError("Missing PINECONE_API_KEY or HUGGINGFACEHUB_API_TOKEN in environment.")

        # ------- Load & split docs -------
        if not os.path.exists("./horoscope.txt"):
            raise FileNotFoundError("Couldn't find './horoscope.txt'. Make sure the file exists.")
        loader = TextLoader("./horoscope.txt")
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=4)
        docs = text_splitter.split_documents(documents)

        # ------- Embeddings -------
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        dim = len(self.embeddings.embed_query("ping"))  # expected 768

        # ------- Pinecone setup -------
        pc = Pinecone(api_key=pinecone_key)
        self.index_name = index_name

        def _idx_name(x):
            return x.name if hasattr(x, "name") else (x.get("name") if isinstance(x, dict) else None)

        existing = {_idx_name(i) for i in pc.list_indexes()}
        if self.index_name not in existing:
            pc.create_index(
                name=self.index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(cloud=cloud, region=region),
            )

        idx = pc.Index(self.index_name)
        stats = idx.describe_index_stats()
        namespaces = stats.get("namespaces", {}) or {}
        total = sum(ns.get("vector_count", 0) for ns in namespaces.values()) if namespaces else stats.get("total_vector_count", 0) or 0

        if total == 0:
            self.vectorstore = PineconeVectorStore.from_documents(docs, embedding=self.embeddings, index_name=self.index_name)
        else:
            self.vectorstore = PineconeVectorStore(index_name=self.index_name, embedding=self.embeddings)

        # ------- LLMs -------
        HF_TOKEN = hf_token

        # Final answerer (fluent)
        self.chat_llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                task="text-generation",
                temperature=0.7,
                do_sample=True,
                top_k=50,
                max_new_tokens=256,
                return_full_text=False,
                huggingfacehub_api_token=HF_TOKEN,
            )
        )

        # Deterministic rewriter (can be same model; different decoding)
        self.rewriter_llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                task="text-generation",
                temperature=0.2,
                do_sample=False,
                max_new_tokens=96,
                return_full_text=False,
                huggingfacehub_api_token=HF_TOKEN,
            )
        )

        # ------- Advanced retriever (MultiQuery + Cross-Encoder rerank + Compression) -------
        base_retriever = self.vectorstore.as_retriever(search_kwargs={"k": 10})
        mqr = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=self.rewriter_llm)

        reranker_model_name = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-base")  # switch to -large if you want
        cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_name)
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=5)

        self.advanced_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=mqr,
        )

        # ------- Helpers -------
        self.SIGNS = [
            "aries","taurus","gemini","cancer","leo","virgo",
            "libra","scorpio","sagittarius","capricorn","aquarius","pisces"
        ]

        def extract_sign_from_text(text: str) -> Optional[str]:
            t = (text or "").lower()
            for s in self.SIGNS:
                if s in t:
                    return s
            return None

        def filter_docs_by_sign(docs_list: List[Document], sign: Optional[str]) -> List[Document]:
            if not sign:
                return docs_list
            s = sign.lower()
            return [d for d in docs_list if s in (d.page_content or "").lower()]

        def extract_citations(docs_list: List[Document]) -> List[str]:
            cites = []
            for i, d in enumerate(docs_list, 1):
                md = (getattr(d, "metadata", {}) or {})
                src = md.get("source") or md.get("file") or "horoscope.txt"
                cites.append(f"[{i}] {src}")
            return cites

        def format_context(docs_list: List[Document]) -> str:
            return "\n\n".join(
                getattr(d, "page_content", "") for d in docs_list if getattr(d, "page_content", "")
            )

        # ------- Tools -------
        @tool("lucky_number")
        def lucky_number(name_or_sign: str) -> str:
            """Deterministic 'lucky number' (1-9) from a name or zodiac sign."""
            h = int(hashlib.md5(name_or_sign.strip().lower().encode("utf-8")).hexdigest(), 16)
            num = (h % 9) + 1
            return f"Lucky number for '{name_or_sign}': {num}"

        @tool("now")
        def now(_: str = "") -> str:
            """Current date/time (ISO format)."""
            return datetime.datetime.now().isoformat(timespec="seconds")

        self.lucky_number = lucky_number
        self.now = now

        # ------- LangGraph (sign-aware): route → tools/RAG → grade → fallback/generate -------
        class RAGState(TypedDict, total=False):
            session_id: str
            question: str
            history: List
            route: Literal["TOOLS", "RAG"]
            target_sign: Optional[str]
            docs: List[Document]
            citations: List[str]
            tool_result: Optional[str]
            grounded: bool
            answer: str

        TOOL_KEYWORDS = ("lucky number", "lucky", "today", "date", "time", "now")

        def route_node(state: RAGState) -> RAGState:
            q = (state.get("question") or "")
            route = "TOOLS" if any(k in q.lower() for k in TOOL_KEYWORDS) else "RAG"
            return {**state, "route": route, "target_sign": extract_sign_from_text(q)}

        def tools_node(state: RAGState) -> RAGState:
            q = state["question"]
            if "lucky" in q.lower():
                target = extract_sign_from_text(q) or q.strip()
                result = self.lucky_number.invoke(target.title())
            else:
                result = self.now.invoke("")
            return {**state, "tool_result": result, "answer": result, "docs": [], "citations": [], "grounded": True}

        def retrieve_node(state: RAGState) -> RAGState:
            docs_list = self.advanced_retriever.invoke(state["question"])  # new API
            sign = state.get("target_sign")
            if sign:
                docs_list = filter_docs_by_sign(docs_list, sign)
            return {**state, "docs": docs_list, "citations": extract_citations(docs_list)}

        def grade_node(state: RAGState) -> RAGState:
            grounded = len(state.get("docs") or []) > 0
            return {**state, "grounded": grounded}

        def fallback_node(state: RAGState) -> RAGState:
            sign = state.get("target_sign")
            if sign:
                msg = f"I don't have content for {sign.title()} yet. Please add it to the corpus."
            else:
                msg = "I don't know yet. Please add relevant content to the corpus."
            return {**state, "answer": msg}

        rag_prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a helpful, concise fortune teller. Use ONLY the provided context and tools. "
             "If you don't know, say you don't know. Keep answers within 2 short sentences."),
            MessagesPlaceholder("history"),
            ("human",
             "Context:\n{context}\n\nTools:\n{tool_result}\n\nQuestion: {question}\n"
             "Cite sources as [1], [2] if used.")
        ])
        generator = rag_prompt | self.chat_llm | StrOutputParser()

        def generate_node(state: RAGState) -> RAGState:
            # If tools already set the answer, skip generation
            if state.get("tool_result"):
                return state
            context = format_context(state.get("docs", [])) if state.get("docs") else ""
            tool_text = state.get("tool_result", "")
            answer = generator.invoke({
                "history": state.get("history", []),
                "context": context,
                "tool_result": tool_text,
                "question": state["question"],
            })
            return {**state, "answer": answer}

        graph = StateGraph(RAGState)
        graph.add_node("route", route_node)
        graph.add_node("tools", tools_node)
        graph.add_node("retrieve", retrieve_node)
        graph.add_node("grade", grade_node)
        graph.add_node("fallback", fallback_node)
        graph.add_node("generate", generate_node)

        graph.add_edge(START, "route")
        graph.add_conditional_edges("route", lambda s: "tools" if s["route"] == "TOOLS" else "retrieve")
        graph.add_edge("tools", END)  # tools short-circuit to END
        graph.add_edge("retrieve", "grade")
        graph.add_conditional_edges("grade", lambda s: "generate" if s["grounded"] else "fallback")
        graph.add_edge("generate", END)
        graph.add_edge("fallback", END)

        self.graph_app = graph.compile()

        # ------- Memory wrapper -------
        self._session_store = {}

        def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
            if session_id not in self._session_store:
                self._session_store[session_id] = InMemoryChatMessageHistory()
            return self._session_store[session_id]

        # Return only final string answer from the graph
        answer_only = self.graph_app | RunnableLambda(lambda state: state["answer"])

        self.graph_with_memory = RunnableWithMessageHistory(
            answer_only,
            get_session_history,
            input_messages_key="question",
            history_messages_key="history",
        )

        # ------- Back-compat: rag_chain.invoke(question) returns string via graph -------
        class _Invoker:
            def __init__(self, outer):
                self._outer = outer
            def invoke(self, question: str, session_id: str = "default"):
                cfg = {"configurable": {"session_id": session_id}}
                return self._outer.graph_with_memory.invoke({"session_id": session_id, "question": question}, config=cfg)

        self.rag_chain = _Invoker(self)

    # Convenience method for callers
    def answer_with_graph(self, question: str, session_id: str = "default") -> str:
        """Run ``question`` through the memory-wrapped graph and return its answer.

        The same ``session_id`` is placed both in the graph input and in the
        ``configurable`` config consumed by the message-history wrapper, so
        calls sharing an id share conversation history.
        """
        graph_input = {"session_id": session_id, "question": question}
        run_config = {"configurable": {"session_id": session_id}}
        return self.graph_with_memory.invoke(graph_input, config=run_config)


if __name__ == "__main__":
    # Demo entry point: ask one question (interactively when stdin is available)
    # and print the graph's answer for the fixed "cli-user" session.
    bot = ChatBot()
    try:
        q = input("Ask me anything: ").strip()
    except (EOFError, NotImplementedError):
        # EOFError: stdin is closed (piped / non-interactive run).
        # NotImplementedError: Jupyter kernels whose frontend cannot serve input
        # requests raise StdinNotImplementedError, a NotImplementedError subclass.
        q = ""
    # A blank reply would push an empty query through retrieval; use the demo question.
    q = q or "What can Sagittarius expect this week?"
    ans = bot.answer_with_graph(q, session_id="cli-user")
    print(ans)


  from .autonotebook import tqdm as notebook_tqdm


 The capital of France is Paris. [1] https://www.worldometers.info/capitals/france-capital/


In [3]:
# from your conda env / venv
!pip install -U langchain-groq groq


Collecting langchain-groq
  Downloading langchain_groq-0.3.7-py3-none-any.whl.metadata (2.6 kB)
Collecting groq
  Downloading groq-0.31.0-py3-none-any.whl.metadata (16 kB)
Downloading langchain_groq-0.3.7-py3-none-any.whl (16 kB)
Downloading groq-0.31.0-py3-none-any.whl (131 kB)
Installing collected packages: groq, langchain-groq
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2/2 [langchain-groq]
Successfully installed groq-0.31.0 langchain-groq-0.3.7


In [5]:
# Smoke test: confirm langchain-groq is importable in this kernel after the install above.
from langchain_groq import ChatGroq 
print('OK:', ChatGroq)

OK: <class 'langchain_groq.chat_models.ChatGroq'>


In [6]:
#main.py for groq
# main.py
import os
from typing import List, Optional, Literal, TypedDict

from dotenv import load_dotenv

# Pinecone + Vector store
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

# Embeddings (local HF download; NOT HF inference)
from langchain_huggingface import HuggingFaceEmbeddings

# Data loading / splitting
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# LangChain core
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda

# Advanced retrieval
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.retrievers import ContextualCompressionRetriever

# Tools
from langchain_core.tools import tool
import hashlib, datetime, re

# LangGraph
from langgraph.graph import StateGraph, START, END

# Memory (version-agnostic)
try:
    from langchain_community.chat_message_histories import ChatMessageHistory as InMemoryChatMessageHistory
except ImportError:
    from langchain_community.chat_message_histories import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# LLMs via Groq (free tier available)
from langchain_groq import ChatGroq


class ChatBot:
    """Horoscope Q&A bot: Pinecone retrieval + local HF embeddings + Groq LLMs, run by LangGraph.

    A LangGraph state machine routes each question either to deterministic tools
    (lucky number / current time) or to a retrieval pipeline (multi-query retrieval
    -> cross-encoder rerank -> optional sign filter -> grade), then generates a
    short answer or a fallback message. Answers are wrapped with per-session chat
    history kept in an in-process dict.
    """

    # Zodiac sign names (lower-case). Their order is the lookup priority when a
    # question mentions more than one sign.
    ZODIAC_SIGNS = (
        "aries", "taurus", "gemini", "cancer", "leo", "virgo",
        "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
    )

    # Whole-word keyword patterns used to route questions to the tools node.
    # Plain substring checks misfire on ordinary words: "know" contains "now",
    # "update" contains "date", "sometimes" contains "time".
    _LUCKY_RE = re.compile(r"\blucky\b")
    _TIME_RE = re.compile(r"\b(?:today|date|time|now)\b")

    @staticmethod
    def _find_sign(text: Optional[str], signs=ZODIAC_SIGNS) -> Optional[str]:
        """Return the first entry of ``signs`` mentioned as a whole word in ``text``.

        Matching is case-insensitive and tolerates a trailing plural "s" ("Leos"),
        but never fires inside another word ("salaries", "Napoleon").
        Returns None when nothing matches or ``text`` is empty/None.
        """
        lowered = (text or "").lower()
        for sign in signs:
            if re.search(rf"\b{re.escape(sign.lower())}s?\b", lowered):
                return sign
        return None

    @classmethod
    def _wants_lucky_number(cls, question: Optional[str]) -> bool:
        """True when the question contains the whole word "lucky"."""
        return bool(cls._LUCKY_RE.search((question or "").lower()))

    @classmethod
    def _is_tool_question(cls, question: Optional[str], sign: Optional[str] = None) -> bool:
        """Decide whether ``question`` is answered by a tool rather than by RAG.

        Lucky-number questions always go to the tools. Time/date words ("today",
        "date", "time", "now") only do so when no zodiac ``sign`` was detected, so
        "What can Leo expect today?" is answered from the corpus instead of with
        a timestamp.
        """
        if cls._wants_lucky_number(question):
            return True
        return not sign and bool(cls._TIME_RE.search((question or "").lower()))

    def __init__(
        self,
        index_name: str = "langchain-demo",
        cloud: str = "aws",
        region: str = "us-east-1",
    ):
        """Build the vector store, retriever, tools, LangGraph app and memory wrapper.

        Args:
            index_name: Pinecone index to use. It is created (cosine metric,
                serverless spec) if missing, and ``./horoscope.txt`` chunks are
                upserted into it when it holds no vectors.
            cloud: Cloud provider for the serverless index spec.
            region: Region for the serverless index spec.

        Raises:
            RuntimeError: if ``PINECONE_API_KEY`` or ``GROQ_API_KEY`` is unset
                (after ``load_dotenv()``).
            FileNotFoundError: if ``./horoscope.txt`` does not exist.
        """
        load_dotenv()

        # --- Required env ---
        pinecone_key = os.getenv("PINECONE_API_KEY")
        groq_key = os.getenv("GROQ_API_KEY")
        if not pinecone_key:
            raise RuntimeError("Missing PINECONE_API_KEY in environment.")
        if not groq_key:
            raise RuntimeError("Missing GROQ_API_KEY in environment.")

        # --- Load & split docs ---
        if not os.path.exists("./horoscope.txt"):
            raise FileNotFoundError("Couldn't find './horoscope.txt'. Make sure the file exists.")
        loader = TextLoader("./horoscope.txt")
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=4)
        docs = text_splitter.split_documents(documents)

        # --- Embeddings (local; no paid inference) ---
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        # Probe the embedding size so a newly created index gets the right dimension.
        dim = len(self.embeddings.embed_query("ping"))  # expected 768

        # --- Pinecone setup ---
        pc = Pinecone(api_key=pinecone_key)
        self.index_name = index_name

        def _idx_name(x):
            # Index listings are handled whether entries expose `.name` or are dicts.
            return x.name if hasattr(x, "name") else (x.get("name") if isinstance(x, dict) else None)

        existing = {_idx_name(i) for i in pc.list_indexes()}
        if self.index_name not in existing:
            pc.create_index(
                name=self.index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(cloud=cloud, region=region),
            )

        idx = pc.Index(self.index_name)
        stats = idx.describe_index_stats()
        namespaces = stats.get("namespaces", {}) or {}
        total = sum(ns.get("vector_count", 0) for ns in namespaces.values()) if namespaces else stats.get("total_vector_count", 0) or 0

        # Upsert the chunks only into an empty index; otherwise reuse the stored vectors.
        if total == 0:
            self.vectorstore = PineconeVectorStore.from_documents(docs, embedding=self.embeddings, index_name=self.index_name)
        else:
            self.vectorstore = PineconeVectorStore(index_name=self.index_name, embedding=self.embeddings)

        # --- LLMs (Groq) ---
        # fast, cheap model for chat + a stricter (lower-temperature) one for query rewriting
        self.chat_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.3, max_tokens=200)
        self.rewriter_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.1, max_tokens=96)

        # --- Advanced retrieval (MultiQuery + rerank + compression) ---
        base_retriever = self.vectorstore.as_retriever(search_kwargs={"k": 10})
        mqr = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=self.rewriter_llm)

        reranker_model_name = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-base")  # change to -large if you prefer
        cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_name)
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=5)

        self.advanced_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=mqr,
        )

        # --- Helpers ---
        self.SIGNS = list(self.ZODIAC_SIGNS)

        def extract_sign_from_text(text: str) -> Optional[str]:
            # Whole-word match, so e.g. "salaries" is not mistaken for Aries.
            return self._find_sign(text, self.SIGNS)

        def filter_docs_by_sign(docs_list: List[Document], sign: Optional[str]) -> List[Document]:
            # Keep chunks whose text mentions the sign anywhere (case-insensitive).
            if not sign:
                return docs_list
            s = sign.lower()
            return [d for d in docs_list if s in (d.page_content or "").lower()]

        def extract_citations(docs_list: List[Document]) -> List[str]:
            # "[n] source" labels, numbered in retrieval order.
            cites = []
            for i, d in enumerate(docs_list, 1):
                md = (getattr(d, "metadata", {}) or {})
                src = md.get("source") or md.get("file") or "horoscope.txt"
                cites.append(f"[{i}] {src}")
            return cites

        def format_context(docs_list: List[Document]) -> str:
            # Non-empty chunk texts separated by blank lines.
            return "\n\n".join(
                getattr(d, "page_content", "") for d in docs_list if getattr(d, "page_content", "")
            )

        self.extract_sign_from_text = extract_sign_from_text
        self.filter_docs_by_sign = filter_docs_by_sign
        self.extract_citations = extract_citations
        self.format_context = format_context

        # --- Tools ---
        @tool("lucky_number")
        def lucky_number(name_or_sign: str) -> str:
            """Deterministic 'lucky number' (1-9) from a name or zodiac sign."""
            h = int(hashlib.md5(name_or_sign.strip().lower().encode("utf-8")).hexdigest(), 16)
            num = (h % 9) + 1
            return f"Lucky number for '{name_or_sign}': {num}"

        @tool("now")
        def now(_: str = "") -> str:
            """Current date/time (ISO format)."""
            return datetime.datetime.now().isoformat(timespec="seconds")

        self.lucky_number = lucky_number
        self.now = now

        # --- LangGraph: route → tools/RAG → grade → fallback/generate (sign-aware) ---
        class RAGState(TypedDict, total=False):
            session_id: str
            question: str
            history: List
            route: Literal["TOOLS", "RAG"]
            target_sign: Optional[str]
            docs: List[Document]
            citations: List[str]
            tool_result: Optional[str]
            grounded: bool
            answer: str

        def route_node(state: RAGState) -> RAGState:
            # Detect the sign once; it drives both routing and later doc filtering.
            q = state.get("question") or ""
            sign = self.extract_sign_from_text(q)
            route = "TOOLS" if self._is_tool_question(q, sign) else "RAG"
            return {**state, "route": route, "target_sign": sign}

        def tools_node(state: RAGState) -> RAGState:
            # The tool output is the final answer (this node goes straight to END).
            q = state["question"]
            if self._wants_lucky_number(q):
                target = self.extract_sign_from_text(q) or q.strip()
                result = self.lucky_number.invoke(target.title())
            else:
                result = self.now.invoke("")
            return {**state, "tool_result": result, "answer": result, "docs": [], "citations": [], "grounded": True}

        def retrieve_node(state: RAGState) -> RAGState:
            docs_list = self.advanced_retriever.invoke(state["question"])
            sign = state.get("target_sign")
            if sign:
                docs_list = self.filter_docs_by_sign(docs_list, sign)
            return {**state, "docs": docs_list, "citations": self.extract_citations(docs_list)}

        def grade_node(state: RAGState) -> RAGState:
            # "Grounded" simply means at least one document survived retrieval/filtering.
            grounded = len(state.get("docs") or []) > 0
            return {**state, "grounded": grounded}

        def fallback_node(state: RAGState) -> RAGState:
            sign = state.get("target_sign")
            if sign:
                msg = f"I don't have content for {sign.title()} yet. Please add it to the corpus."
            else:
                msg = "I don't know yet. Please add relevant content to the corpus."
            return {**state, "answer": msg}

        rag_prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a concise fortune teller.\n"
             "- Use ONLY the provided context and tool output.\n"
             "- Answer the user's single question in 1–2 sentences.\n"
             "- Do NOT invent new questions or headings.\n"
             "- If the context is empty/irrelevant, reply exactly: I don't know."),
            MessagesPlaceholder("history"),
            ("human",
             "Context:\n{context}\n\nTool:\n{tool_result}\n\nUser question: {question}\n"
             "Citations (optional): {citations}")
        ])
        generator = rag_prompt | self.chat_llm | StrOutputParser()

        def generate_node(state: RAGState) -> RAGState:
            # Tools route straight to END; this only guards against future rewiring.
            if state.get("tool_result"):
                return state
            context = self.format_context(state.get("docs", [])) if state.get("docs") else ""
            tool_text = state.get("tool_result", "")
            cites = " ".join(state.get("citations", []))
            answer = generator.invoke({
                "history": state.get("history", []),
                "context": context,
                "tool_result": tool_text,
                "question": state["question"],
                "citations": cites,
            })
            # Strip any "Question: ..." lines the model invents after its answer.
            answer = re.sub(r'(?mi)^\s*Question:.*$', '', answer).strip()
            return {**state, "answer": answer}

        graph = StateGraph(RAGState)
        graph.add_node("route", route_node)
        graph.add_node("tools", tools_node)
        graph.add_node("retrieve", retrieve_node)
        graph.add_node("grade", grade_node)
        graph.add_node("fallback", fallback_node)
        graph.add_node("generate", generate_node)

        graph.add_edge(START, "route")
        graph.add_conditional_edges("route", lambda s: "tools" if s["route"] == "TOOLS" else "retrieve")
        graph.add_edge("tools", END)
        graph.add_edge("retrieve", "grade")
        graph.add_conditional_edges("grade", lambda s: "generate" if s["grounded"] else "fallback")
        graph.add_edge("generate", END)
        graph.add_edge("fallback", END)

        self.graph_app = graph.compile()

        # --- Memory wrapper ---
        # Histories live in this dict only, keyed by session id (lost on restart).
        self._session_store = {}
        def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
            if session_id not in self._session_store:
                self._session_store[session_id] = InMemoryChatMessageHistory()
            return self._session_store[session_id]

        # Expose only the final answer string so the history wrapper records it as the AI turn.
        answer_only = self.graph_app | RunnableLambda(lambda state: state["answer"])
        self.graph_with_memory = RunnableWithMessageHistory(
            answer_only,
            get_session_history,
            input_messages_key="question",
            history_messages_key="history",
        )

        # Back-compat for Streamlit: `bot.rag_chain.invoke(question, session_id)`.
        class _Invoker:
            def __init__(self, outer):
                self._outer = outer
            def invoke(self, question: str, session_id: str = "default"):
                return self._outer.answer_with_graph(question, session_id=session_id)
        self.rag_chain = _Invoker(self)

    # Convenience method
    def answer_with_graph(self, question: str, session_id: str = "default") -> str:
        """Answer ``question`` via the memory-wrapped graph.

        ``session_id`` selects the chat-history bucket (created on first use), so
        calls sharing an id share conversation context.
        """
        cfg = {"configurable": {"session_id": session_id}}
        return self.graph_with_memory.invoke({"session_id": session_id, "question": question}, config=cfg)


if __name__ == "__main__":
    # Demo entry point: ask one question (interactively when stdin is available)
    # and print the graph's answer for the fixed "cli-user" session.
    bot = ChatBot()
    try:
        q = input("Ask me anything: ").strip()
    except (EOFError, NotImplementedError):
        # EOFError: stdin is closed (piped / non-interactive run).
        # NotImplementedError: Jupyter kernels whose frontend cannot serve input
        # requests raise StdinNotImplementedError, a NotImplementedError subclass.
        q = ""
    # A blank reply would push an empty query through retrieval; use the demo question.
    q = q or "What can Sagittarius expect this week?"
    ans = bot.answer_with_graph(q, session_id="cli-user")
    print(ans)


As an Aries, you may face communication and tech disruptions this year due to Mercury retrograde periods, potentially affecting your travel plans and daily routines. However, by the end of November, the sun's focus on your first house will bring opportunities for deeper connections and new experiences.


In [2]:
# main.py
import os
import re
from typing import List, Optional, Literal, TypedDict

from dotenv import load_dotenv

# Pinecone + Vector store
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

# Embeddings (local HF download; NOT HF inference)
from langchain_huggingface import HuggingFaceEmbeddings

# Data loading / splitting
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# LangChain core
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda

# Advanced retrieval
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.retrievers import ContextualCompressionRetriever

# Tools
from langchain_core.tools import tool
import hashlib, datetime

# LangGraph
from langgraph.graph import StateGraph, START, END

# Memory (version-agnostic)
try:
    from langchain_community.chat_message_histories import ChatMessageHistory as InMemoryChatMessageHistory
except ImportError:
    from langchain_community.chat_message_histories import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# LLMs via Groq
from langchain_groq import ChatGroq


class ChatBot:
    """Sign-aware horoscope Q&A bot: Pinecone + local HF embeddings + Groq, run by LangGraph.

    Each chunk of ``./horoscope.txt`` is tagged with the zodiac sign it mentions
    most (``metadata["sign"]``) before indexing. At query time the graph routes to
    deterministic tools (lucky number / current time) or to retrieval
    (multi-query -> cross-encoder rerank -> strict metadata sign filter -> grade),
    then generates a short answer or a fallback message. Answers are wrapped with
    per-session chat history kept in an in-process dict.
    """

    # Zodiac sign names (lower-case). Their order is the lookup priority when a
    # question mentions more than one sign.
    ZODIAC_SIGNS = (
        "aries", "taurus", "gemini", "cancer", "leo", "virgo",
        "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
    )

    # Whole-word, case-insensitive sign pattern used to tag chunks at index time.
    _TAG_RE = re.compile(r"\b(" + "|".join(ZODIAC_SIGNS) + r")\b", flags=re.I)

    # Whole-word keyword patterns used to route questions to the tools node.
    # Plain substring checks misfire on ordinary words: "know" contains "now",
    # "update" contains "date", "sometimes" contains "time".
    _LUCKY_RE = re.compile(r"\blucky\b")
    _TIME_RE = re.compile(r"\b(?:today|date|time|now)\b")

    @staticmethod
    def _find_sign(text: Optional[str], signs=ZODIAC_SIGNS) -> Optional[str]:
        """Return the first entry of ``signs`` mentioned as a whole word in ``text``.

        Matching is case-insensitive and tolerates a trailing plural "s" ("Leos"),
        but never fires inside another word ("salaries", "Napoleon").
        Returns None when nothing matches or ``text`` is empty/None.
        """
        lowered = (text or "").lower()
        for sign in signs:
            if re.search(rf"\b{re.escape(sign.lower())}s?\b", lowered):
                return sign
        return None

    @classmethod
    def _dominant_sign(cls, text: Optional[str]) -> Optional[str]:
        """Return the most frequently mentioned sign in ``text`` (lower-case), or None.

        Ties go to the sign mentioned first, so tagging is deterministic across
        processes (iterating a ``set`` is not: string hashing is randomised per run).
        """
        counts = {}
        for match in cls._TAG_RE.findall(text or ""):
            key = match.lower()
            counts[key] = counts.get(key, 0) + 1
        # dict keeps first-seen order and max() returns the first maximal key.
        return max(counts, key=counts.get) if counts else None

    @classmethod
    def _wants_lucky_number(cls, question: Optional[str]) -> bool:
        """True when the question contains the whole word "lucky"."""
        return bool(cls._LUCKY_RE.search((question or "").lower()))

    @classmethod
    def _is_tool_question(cls, question: Optional[str], sign: Optional[str] = None) -> bool:
        """Decide whether ``question`` is answered by a tool rather than by RAG.

        Lucky-number questions always go to the tools. Time/date words ("today",
        "date", "time", "now") only do so when no zodiac ``sign`` was detected, so
        "What can Leo expect today?" is answered from the corpus instead of with
        a timestamp.
        """
        if cls._wants_lucky_number(question):
            return True
        return not sign and bool(cls._TIME_RE.search((question or "").lower()))

    def __init__(
        self,
        index_name: str = "langchain-demo-signed",   # clean index for sign metadata
        cloud: str = "aws",
        region: str = "us-east-1",
    ):
        """Build the sign-tagged vector store, retriever, tools, graph and memory wrapper.

        Args:
            index_name: Pinecone index to use. It is created (cosine metric,
                serverless spec) if missing, and the sign-tagged chunks are upserted
                into it when it holds no vectors. An already-populated index is
                reused as-is, so its chunks keep whatever metadata they were stored with.
            cloud: Cloud provider for the serverless index spec.
            region: Region for the serverless index spec.

        Raises:
            RuntimeError: if ``PINECONE_API_KEY`` or ``GROQ_API_KEY`` is unset
                (after ``load_dotenv()``).
            FileNotFoundError: if ``./horoscope.txt`` does not exist.
        """
        load_dotenv()

        # --- Required env ---
        pinecone_key = os.getenv("PINECONE_API_KEY")
        groq_key = os.getenv("GROQ_API_KEY")
        if not pinecone_key:
            raise RuntimeError("Missing PINECONE_API_KEY in environment.")
        if not groq_key:
            raise RuntimeError("Missing GROQ_API_KEY in environment.")

        # --- Load & split docs ---
        if not os.path.exists("./horoscope.txt"):
            raise FileNotFoundError("Couldn't find './horoscope.txt'. Make sure the file exists.")
        loader = TextLoader("./horoscope.txt")
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=4)
        docs = text_splitter.split_documents(documents)

        # --- Tag each chunk with a dominant zodiac sign in metadata (NO NULLS) ---
        for d in docs:
            dominant = self._dominant_sign(d.page_content)
            d.metadata = (d.metadata or {})
            # Only set the key if we actually have a value (avoid null metadata values)
            if dominant:
                d.metadata["sign"] = dominant  # e.g., "sagittarius"
            else:
                d.metadata.pop("sign", None)
            d.metadata.setdefault("source", "horoscope.txt")

        # --- Embeddings (local; no paid inference) ---
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        # Probe the embedding size so a newly created index gets the right dimension.
        dim = len(self.embeddings.embed_query("ping"))  # expected 768

        # --- Pinecone setup ---
        pc = Pinecone(api_key=pinecone_key)
        self.index_name = index_name

        def _idx_name(x):
            # Index listings are handled whether entries expose `.name` or are dicts.
            return x.name if hasattr(x, "name") else (x.get("name") if isinstance(x, dict) else None)

        existing = {_idx_name(i) for i in pc.list_indexes()}
        if self.index_name not in existing:
            pc.create_index(
                name=self.index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(cloud=cloud, region=region),
            )

        idx = pc.Index(self.index_name)
        stats = idx.describe_index_stats()
        namespaces = stats.get("namespaces", {}) or {}
        total = sum(ns.get("vector_count", 0) for ns in namespaces.values()) if namespaces else stats.get("total_vector_count", 0) or 0

        # If index is empty, upsert with sign metadata; otherwise connect
        if total == 0:
            self.vectorstore = PineconeVectorStore.from_documents(
                docs, embedding=self.embeddings, index_name=self.index_name
            )
        else:
            self.vectorstore = PineconeVectorStore(index_name=self.index_name, embedding=self.embeddings)

        # --- LLMs (Groq) ---
        # Chat model for answers + a lower-temperature one for query rewriting.
        self.chat_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.3, max_tokens=200)
        self.rewriter_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.1, max_tokens=96)

        # --- Advanced retrieval (MultiQuery + rerank + compression) ---
        base_retriever = self.vectorstore.as_retriever(search_kwargs={"k": 10})
        mqr = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=self.rewriter_llm)

        reranker_model_name = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-base")
        cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_name)
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=5)

        self.advanced_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=mqr,
        )

        # --- Helpers ---
        self.SIGNS = list(self.ZODIAC_SIGNS)

        def extract_sign_from_text(text: str) -> Optional[str]:
            # Whole-word match, so e.g. "salaries" is not mistaken for Aries.
            return self._find_sign(text, self.SIGNS)

        # STRICT: keep only docs whose metadata["sign"] matches the target sign
        def filter_docs_by_sign(docs_list: List[Document], sign: Optional[str]) -> List[Document]:
            if not sign:
                return docs_list
            s = sign.lower()
            keep: List[Document] = []
            for d in docs_list:
                md_sign = ((d.metadata or {}).get("sign") or "").lower()
                if md_sign == s:
                    keep.append(d)
            return keep

        def extract_citations(docs_list: List[Document]) -> List[str]:
            # "[n] source" labels, numbered in retrieval order.
            cites = []
            for i, d in enumerate(docs_list, 1):
                md = (getattr(d, "metadata", {}) or {})
                src = md.get("source") or md.get("file") or "horoscope.txt"
                cites.append(f"[{i}] {src}")
            return cites

        def format_context(docs_list: List[Document]) -> str:
            # Non-empty chunk texts separated by blank lines.
            return "\n\n".join(
                getattr(d, "page_content", "") for d in docs_list if getattr(d, "page_content", "")
            )

        self.extract_sign_from_text = extract_sign_from_text
        self.filter_docs_by_sign = filter_docs_by_sign
        self.extract_citations = extract_citations
        self.format_context = format_context

        # --- Tools ---
        @tool("lucky_number")
        def lucky_number(name_or_sign: str) -> str:
            """Deterministic 'lucky number' (1-9) from a name or zodiac sign."""
            h = int(hashlib.md5(name_or_sign.strip().lower().encode("utf-8")).hexdigest(), 16)
            num = (h % 9) + 1
            return f"Lucky number for '{name_or_sign}': {num}"

        @tool("now")
        def now(_: str = "") -> str:
            """Current date/time (ISO format)."""
            return datetime.datetime.now().isoformat(timespec="seconds")

        self.lucky_number = lucky_number
        self.now = now

        # --- LangGraph: route → tools/RAG → grade → fallback/generate (sign-aware) ---
        class RAGState(TypedDict, total=False):
            session_id: str
            question: str
            history: List
            route: Literal["TOOLS", "RAG"]
            target_sign: Optional[str]
            docs: List[Document]
            citations: List[str]
            tool_result: Optional[str]
            grounded: bool
            answer: str

        def route_node(state: RAGState) -> RAGState:
            # Detect the sign once; it drives both routing and later doc filtering.
            q = state.get("question") or ""
            sign = self.extract_sign_from_text(q)
            route = "TOOLS" if self._is_tool_question(q, sign) else "RAG"
            return {**state, "route": route, "target_sign": sign}

        def tools_node(state: RAGState) -> RAGState:
            # The tool output is the final answer (this node goes straight to END).
            q = state["question"]
            if self._wants_lucky_number(q):
                target = self.extract_sign_from_text(q) or q.strip()
                result = self.lucky_number.invoke(target.title())
            else:
                result = self.now.invoke("")
            return {**state, "tool_result": result, "answer": result, "docs": [], "citations": [], "grounded": True}

        def retrieve_node(state: RAGState) -> RAGState:
            q = state["question"]
            sign = state.get("target_sign")
            biased_query = f"[{sign}] {q}" if sign else q  # bias retrieval toward the sign
            docs_list = self.advanced_retriever.invoke(biased_query)
            if sign:
                docs_list = self.filter_docs_by_sign(docs_list, sign)
            return {**state, "docs": docs_list, "citations": self.extract_citations(docs_list)}

        def grade_node(state: RAGState) -> RAGState:
            # "Grounded" simply means at least one document survived retrieval/filtering.
            grounded = len(state.get("docs") or []) > 0
            return {**state, "grounded": grounded}

        def fallback_node(state: RAGState) -> RAGState:
            sign = state.get("target_sign")
            if sign:
                msg = f"I don't have content for {sign.title()} yet. Please add it to the corpus."
            else:
                msg = "I don't know yet. Please add relevant content to the corpus."
            return {**state, "answer": msg}

        rag_prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a concise fortune teller.\n"
             "- Only answer about the user's sign: {target_sign} (if provided).\n"
             "- Use ONLY the provided context and tool output.\n"
             "- Answer the user's single question in 1–2 sentences.\n"
             "- Do NOT invent new questions or headings.\n"
             "- If the context is empty or about a different sign, reply exactly: I don't know."),
            MessagesPlaceholder("history"),
            ("human",
             "Context:\n{context}\n\nTool:\n{tool_result}\n\nUser question: {question}\n"
             "Citations (optional): {citations}")
        ])
        generator = rag_prompt | self.chat_llm | StrOutputParser()

        def generate_node(state: RAGState) -> RAGState:
            # Tools route straight to END; this only guards against future rewiring.
            if state.get("tool_result"):
                return state
            context = self.format_context(state.get("docs", [])) if state.get("docs") else ""
            tool_text = state.get("tool_result", "")
            cites = " ".join(state.get("citations", []))
            answer = generator.invoke({
                "history": state.get("history", []),
                "context": context,
                "tool_result": tool_text,
                "question": state["question"],
                "citations": cites,
                "target_sign": (state.get("target_sign") or ""),
            })
            # Strip any "Question: ..." lines the model invents after its answer.
            answer = re.sub(r'(?mi)^\s*Question:.*$', '', answer).strip()
            return {**state, "answer": answer}

        graph = StateGraph(RAGState)
        graph.add_node("route", route_node)
        graph.add_node("tools", tools_node)
        graph.add_node("retrieve", retrieve_node)
        graph.add_node("grade", grade_node)
        graph.add_node("fallback", fallback_node)
        graph.add_node("generate", generate_node)

        graph.add_edge(START, "route")
        graph.add_conditional_edges("route", lambda s: "tools" if s["route"] == "TOOLS" else "retrieve")
        graph.add_edge("tools", END)
        graph.add_edge("retrieve", "grade")
        graph.add_conditional_edges("grade", lambda s: "generate" if s["grounded"] else "fallback")
        graph.add_edge("generate", END)
        graph.add_edge("fallback", END)

        self.graph_app = graph.compile()

        # --- Memory wrapper ---
        # Histories live in this dict only, keyed by session id (lost on restart).
        self._session_store = {}
        def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
            if session_id not in self._session_store:
                self._session_store[session_id] = InMemoryChatMessageHistory()
            return self._session_store[session_id]

        # Expose only the final answer string so the history wrapper records it as the AI turn.
        answer_only = self.graph_app | RunnableLambda(lambda state: state["answer"])
        self.graph_with_memory = RunnableWithMessageHistory(
            answer_only,
            get_session_history,
            input_messages_key="question",
            history_messages_key="history",
        )

        # Back-compat for Streamlit: `bot.rag_chain.invoke(question, session_id)`.
        class _Invoker:
            def __init__(self, outer):
                self._outer = outer
            def invoke(self, question: str, session_id: str = "default"):
                return self._outer.answer_with_graph(question, session_id=session_id)
        self.rag_chain = _Invoker(self)

    # Convenience method
    def answer_with_graph(self, question: str, session_id: str = "default") -> str:
        """Answer ``question`` via the memory-wrapped graph.

        ``session_id`` selects the chat-history bucket (created on first use), so
        calls sharing an id share conversation context.
        """
        cfg = {"configurable": {"session_id": session_id}}
        return self.graph_with_memory.invoke({"session_id": session_id, "question": question}, config=cfg)


if __name__ == "__main__":
    # Demo entry point: ask one question (interactively when stdin is available)
    # and print the graph's answer for the fixed "cli-user" session.
    bot = ChatBot()
    try:
        q = input("Ask me anything: ").strip()
    except (EOFError, NotImplementedError):
        # EOFError: stdin is closed (piped / non-interactive run).
        # NotImplementedError: Jupyter kernels whose frontend cannot serve input
        # requests raise StdinNotImplementedError, a NotImplementedError subclass.
        q = ""
    # A blank reply would push an empty query through retrieval; use the demo question.
    q = q or "What can Sagittarius expect this week?"
    ans = bot.answer_with_graph(q, session_id="cli-user")
    print(ans)


You can expect a mix of challenges and opportunities in 2025, with Mercury retrograde periods potentially disrupting your travel plans and communication. However, the Sun's focus on your first house from November to December will bring new connections and experiences, helping you discover fresh perspectives and expand your horizons.


In [3]:
# main.py
import os
import re
from typing import List, Optional, Literal, TypedDict
from collections import Counter

from dotenv import load_dotenv

# Pinecone + Vector store
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

# Embeddings (local HF download; NOT HF inference)
from langchain_huggingface import HuggingFaceEmbeddings

# Data loading / splitting
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# LangChain core
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda

# Advanced retrieval
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.retrievers import ContextualCompressionRetriever

# Tools
from langchain_core.tools import tool
import hashlib, datetime

# LangGraph
from langgraph.graph import StateGraph, START, END

# Memory (version-agnostic)
try:
    from langchain_community.chat_message_histories import ChatMessageHistory as InMemoryChatMessageHistory
except ImportError:
    from langchain_community.chat_message_histories import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# LLMs via Groq
from langchain_groq import ChatGroq


class ChatBot:
    """Horoscope Q&A bot: Pinecone retrieval + Groq LLMs wired together with LangGraph.

    Each question is routed either to a deterministic tool (lucky number /
    current time) or to a sign-aware RAG path (retrieve -> grade -> generate,
    with a canned fallback when no document survives the sign filter).
    Chat history is kept per ``session_id`` in an in-process dict.
    """

    # Zodiac signs recognised in corpus text and in questions (lowercase).
    # Tuple order decides which sign wins when a question mentions several.
    _SIGN_NAMES = (
        "aries", "taurus", "gemini", "cancer", "leo", "virgo",
        "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
    )
    # Whole-word, case-insensitive matcher. Plain substring checks mistook
    # "libraries"/"boundaries" for Aries and "leopard" for Leo.
    _SIGN_RE = re.compile(r"\b(" + "|".join(_SIGN_NAMES) + r")\b", flags=re.I)

    # Keywords that send a question to the tools instead of retrieval.
    _TOOL_KEYWORDS = ("lucky number", "lucky", "today", "date", "time", "now")
    # Whole-word matcher, so "know", "update" or "timeline" no longer trigger a tool.
    _TOOL_RE = re.compile(
        r"\b(?:" + "|".join(map(re.escape, _TOOL_KEYWORDS)) + r")\b", flags=re.I
    )

    def __init__(
        self,
        index_name: str = "langchain-demo-signed-v2",  # NEW name to force clean reindex
        cloud: str = "aws",
        region: str = "us-east-1",
    ):
        """Load the corpus, connect/populate Pinecone, and compile the LangGraph app.

        Args:
            index_name: Pinecone index name. Created when missing; documents are
                upserted only when the index holds no vectors yet.
            cloud: Cloud provider for a newly created serverless index.
            region: Region for a newly created serverless index.

        Raises:
            RuntimeError: If PINECONE_API_KEY or GROQ_API_KEY is not set.
            FileNotFoundError: If ./horoscope.txt does not exist.
        """
        load_dotenv()

        # --- Required env ---
        pinecone_key = os.getenv("PINECONE_API_KEY")
        groq_key = os.getenv("GROQ_API_KEY")
        if not pinecone_key:
            raise RuntimeError("Missing PINECONE_API_KEY in environment.")
        if not groq_key:
            raise RuntimeError("Missing GROQ_API_KEY in environment.")

        # --- Load base docs ---
        if not os.path.exists("./horoscope.txt"):
            raise FileNotFoundError("Couldn't find './horoscope.txt'. Make sure the file exists.")
        loader = TextLoader("./horoscope.txt")
        base_docs = loader.load()  # list[Document]

        # --- Tag each BASE document once with its dominant sign (no nulls) ---
        # NOTE(review): if TextLoader yields a single Document for the whole file
        # (presumably it does), every chunk inherits the SAME sign and the strict
        # sign filter below discards content for all other signs. Tagging per chunk
        # would avoid that, but an already-populated index would need rebuilding.
        SIGN_LIST = list(self._SIGN_NAMES)
        sign_re = self._SIGN_RE

        for d in base_docs:
            text = d.page_content or ""
            found = [s.lower() for s in sign_re.findall(text)]
            d.metadata = (d.metadata or {})
            if found:
                dominant = Counter(found).most_common(1)[0][0]
                d.metadata["sign"] = dominant      # e.g., "sagittarius"
            else:
                d.metadata.pop("sign", None)       # ensure no null/None
            d.metadata.setdefault("source", "horoscope.txt")

        # --- Split AFTER tagging so chunks inherit metadata['sign'] ---
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=4)
        docs = text_splitter.split_documents(base_docs)

        # --- Embeddings ---
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        dim = len(self.embeddings.embed_query("ping"))  # expected 768

        # --- Pinecone setup ---
        pc = Pinecone(api_key=pinecone_key)
        self.index_name = index_name

        def _idx_name(x):
            # list_indexes() entries may be objects or dicts depending on client version.
            return x.name if hasattr(x, "name") else (x.get("name") if isinstance(x, dict) else None)

        existing = {_idx_name(i) for i in pc.list_indexes()}
        if self.index_name not in existing:
            pc.create_index(
                name=self.index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(cloud=cloud, region=region),
            )

        idx = pc.Index(self.index_name)
        stats = idx.describe_index_stats()
        namespaces = stats.get("namespaces", {}) or {}
        total = sum(ns.get("vector_count", 0) for ns in namespaces.values()) if namespaces else stats.get("total_vector_count", 0) or 0

        # If index is empty, upsert; otherwise connect
        if total == 0:
            self.vectorstore = PineconeVectorStore.from_documents(
                docs, embedding=self.embeddings, index_name=self.index_name
            )
        else:
            self.vectorstore = PineconeVectorStore(index_name=self.index_name, embedding=self.embeddings)

        # --- LLMs (Groq) ---
        self.chat_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.3, max_tokens=200)
        self.rewriter_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.1, max_tokens=96)

        # --- Advanced retrieval (MultiQuery + rerank + compression) ---
        base_retriever = self.vectorstore.as_retriever(search_kwargs={"k": 10})
        mqr = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=self.rewriter_llm)

        reranker_model_name = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-base")
        cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_name)
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=5)

        self.advanced_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=mqr,
        )

        # --- Helpers ---
        self.SIGNS = SIGN_LIST

        def extract_sign_from_text(text: str) -> Optional[str]:
            # Whole-word match; see ChatBot._find_sign.
            return ChatBot._find_sign(text)

        # STRICT: keep only docs whose metadata["sign"] matches the target sign
        def filter_docs_by_sign(docs_list: List[Document], sign: Optional[str]) -> List[Document]:
            if not sign:
                return docs_list
            s = sign.lower()
            keep: List[Document] = []
            for d in docs_list:
                md_sign = ((d.metadata or {}).get("sign") or "").lower()
                if md_sign == s:
                    keep.append(d)
            return keep

        def extract_citations(docs_list: List[Document]) -> List[str]:
            cites = []
            for i, d in enumerate(docs_list, 1):
                md = (getattr(d, "metadata", {}) or {})
                src = md.get("source") or md.get("file") or "horoscope.txt"
                cites.append(f"[{i}] {src}")
            return cites

        def format_context(docs_list: List[Document]) -> str:
            return "\n\n".join(
                getattr(d, "page_content", "") for d in docs_list if getattr(d, "page_content", "")
            )

        self.extract_sign_from_text = extract_sign_from_text
        self.filter_docs_by_sign = filter_docs_by_sign
        self.extract_citations = extract_citations
        self.format_context = format_context

        # --- Tools ---
        @tool("lucky_number")
        def lucky_number(name_or_sign: str) -> str:
            """Deterministic 'lucky number' (1-9) from a name or zodiac sign."""
            h = int(hashlib.md5(name_or_sign.strip().lower().encode("utf-8")).hexdigest(), 16)
            num = (h % 9) + 1
            return f"Lucky number for '{name_or_sign}': {num}"

        @tool("now")
        def now(_: str = "") -> str:
            """Current date/time (ISO format)."""
            return datetime.datetime.now().isoformat(timespec="seconds")

        self.lucky_number = lucky_number
        self.now = now

        # --- LangGraph: route → tools/RAG → grade → fallback/generate (sign-aware) ---
        class RAGState(TypedDict, total=False):
            session_id: str
            question: str
            history: List
            route: Literal["TOOLS", "RAG"]
            target_sign: Optional[str]
            docs: List[Document]
            citations: List[str]
            tool_result: Optional[str]
            grounded: bool
            answer: str

        def route_node(state: RAGState) -> RAGState:
            q = (state.get("question") or "")
            # NOTE(review): "today"/"now" also occur in horoscope questions
            # ("What does today hold for Leo?"), which therefore go to the tools.
            route = "TOOLS" if ChatBot._is_tool_question(q) else "RAG"
            return {**state, "route": route, "target_sign": self.extract_sign_from_text(q)}

        def tools_node(state: RAGState) -> RAGState:
            q = state["question"]
            # Whole-word check, consistent with the router.
            if re.search(r"\blucky\b", q, flags=re.I):
                target = self.extract_sign_from_text(q) or q.strip()
                result = self.lucky_number.invoke(target.title())
            else:
                result = self.now.invoke("")
            return {**state, "tool_result": result, "answer": result, "docs": [], "citations": [], "grounded": True}

        def retrieve_node(state: RAGState) -> RAGState:
            q = state["question"]
            sign = state.get("target_sign")
            biased_query = f"[{sign}] {q}" if sign else q  # bias retrieval toward the sign
            docs_list = self.advanced_retriever.invoke(biased_query)
            if sign:
                docs_list = self.filter_docs_by_sign(docs_list, sign)
            return {**state, "docs": docs_list, "citations": self.extract_citations(docs_list)}

        def grade_node(state: RAGState) -> RAGState:
            # "Grounded" simply means at least one document survived retrieval/filtering.
            grounded = len(state.get("docs") or []) > 0
            return {**state, "grounded": grounded}

        def fallback_node(state: RAGState) -> RAGState:
            sign = state.get("target_sign")
            if sign:
                msg = f"I don't have content for {sign.title()} yet. Please add it to the corpus."
            else:
                msg = "I don't know yet. Please add relevant content to the corpus."
            return {**state, "answer": msg}

        rag_prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a concise fortune teller.\n"
             "- Only answer about the user's sign: {target_sign} (if provided).\n"
             "- Use ONLY the provided context and tool output.\n"
             "- Answer the user's single question in 1–2 sentences.\n"
             "- Do NOT invent new questions or headings.\n"
             "- If the context is empty or about a different sign, reply exactly: I don't know."),
            MessagesPlaceholder("history"),
            ("human",
             "Context:\n{context}\n\nTool:\n{tool_result}\n\nUser question: {question}\n"
             "Citations (optional): {citations}")
        ])
        generator = rag_prompt | self.chat_llm | StrOutputParser()

        def generate_node(state: RAGState) -> RAGState:
            # Double-safety: if no docs survived filtering, do NOT generate.
            if state.get("tool_result"):
                return state
            if not state.get("docs"):
                sign = state.get("target_sign")
                msg = f"I don't have content for {sign.title()} yet. Please add it to the corpus." if sign else "I don't know yet. Please add relevant content to the corpus."
                return {**state, "answer": msg}

            context = self.format_context(state.get("docs", []))
            tool_text = state.get("tool_result", "")
            cites = " ".join(state.get("citations", []))
            answer = generator.invoke({
                "history": state.get("history", []),
                "context": context,
                "tool_result": tool_text,
                "question": state["question"],
                "citations": cites,
                "target_sign": (state.get("target_sign") or ""),
            })
            # Strip any "Question: ..." lines the model may echo back.
            answer = re.sub(r'(?mi)^\s*Question:.*$', '', answer).strip()
            return {**state, "answer": answer}

        graph = StateGraph(RAGState)
        graph.add_node("route", route_node)
        graph.add_node("tools", tools_node)
        graph.add_node("retrieve", retrieve_node)
        graph.add_node("grade", grade_node)
        graph.add_node("fallback", fallback_node)
        graph.add_node("generate", generate_node)

        graph.add_edge(START, "route")
        graph.add_conditional_edges("route", lambda s: "tools" if s["route"] == "TOOLS" else "retrieve")
        graph.add_edge("tools", END)
        graph.add_edge("retrieve", "grade")
        graph.add_conditional_edges("grade", lambda s: "generate" if s["grounded"] else "fallback")
        graph.add_edge("generate", END)
        graph.add_edge("fallback", END)

        self.graph_app = graph.compile()

        # --- Memory wrapper ---
        self._session_store = {}
        def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
            if session_id not in self._session_store:
                self._session_store[session_id] = InMemoryChatMessageHistory()
            return self._session_store[session_id]

        answer_only = self.graph_app | RunnableLambda(lambda state: state["answer"])
        self.graph_with_memory = RunnableWithMessageHistory(
            answer_only,
            get_session_history,
            input_messages_key="question",
            history_messages_key="history",
        )

        # Back-compat for Streamlit
        class _Invoker:
            def __init__(self, outer):
                self._outer = outer
            def invoke(self, question: str, session_id: str = "default"):
                cfg = {"configurable": {"session_id": session_id}}
                return self._outer.graph_with_memory.invoke({"session_id": session_id, "question": question}, config=cfg)
        self.rag_chain = _Invoker(self)

    # Convenience method
    def answer_with_graph(self, question: str, session_id: str = "default") -> str:
        """Answer ``question`` through the memory-wrapped graph for ``session_id``."""
        cfg = {"configurable": {"session_id": session_id}}
        return self.graph_with_memory.invoke({"session_id": session_id, "question": question}, config=cfg)

    # --- Text-matching helpers (stateless) ---
    @staticmethod
    def _find_sign(text: Optional[str]) -> Optional[str]:
        """Return the zodiac sign mentioned as a whole word in ``text``, else None.

        When several signs are mentioned, the one earliest in ``_SIGN_NAMES`` wins.
        """
        mentioned = {m.lower() for m in ChatBot._SIGN_RE.findall(text or "")}
        return next((s for s in ChatBot._SIGN_NAMES if s in mentioned), None)

    @staticmethod
    def _is_tool_question(text: Optional[str]) -> bool:
        """True when ``text`` contains one of ``_TOOL_KEYWORDS`` as a whole word/phrase."""
        return ChatBot._TOOL_RE.search(text or "") is not None


if __name__ == "__main__":
    bot = ChatBot()
    # Used when stdin is closed (EOFError) or the user just presses Enter.
    default_question = "What can Sagittarius expect this week?"
    try:
        q = input("Ask me anything: ").strip()
    except EOFError:
        q = ""
    if not q:
        # A blank question would still go through routing, retrieval and an LLM call.
        q = default_question
    ans = bot.answer_with_graph(q, session_id="cli-user")
    print(ans)


I don't have content for Aries yet. Please add it to the corpus.


In [7]:
# main.py
import os
import re
from typing import List, Optional, Literal, TypedDict
from collections import Counter

from dotenv import load_dotenv

# Pinecone + Vector store
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

# Embeddings (local HF download; NOT HF inference)
from langchain_huggingface import HuggingFaceEmbeddings

# Data loading / splitting
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# LangChain core
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda

# Advanced retrieval
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.retrievers import ContextualCompressionRetriever

# Tools
from langchain_core.tools import tool
import hashlib, datetime

# LangGraph
from langgraph.graph import StateGraph, START, END

# Memory (version-agnostic)
try:
    from langchain_community.chat_message_histories import ChatMessageHistory as InMemoryChatMessageHistory
except ImportError:
    from langchain_community.chat_message_histories import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# LLMs via Groq
from langchain_groq import ChatGroq


class ChatBot:
    """Horoscope Q&A bot: Pinecone retrieval + Groq LLMs wired together with LangGraph.

    Each question is routed to one of three paths: deterministic tools (lucky
    number / current time), a general-knowledge LLM answer, or a sign-aware RAG
    path (retrieve -> grade -> generate, with a canned fallback when no document
    survives the sign filter). Chat history is kept per ``session_id`` in an
    in-process dict and can be saved to / loaded from JSON files.
    """

    # Zodiac signs recognised in corpus text and in questions (lowercase).
    # Tuple order decides which sign wins when a question mentions several.
    _SIGN_NAMES = (
        "aries", "taurus", "gemini", "cancer", "leo", "virgo",
        "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces",
    )
    # Whole-word, case-insensitive matcher. Plain substring checks mistook
    # "libraries"/"boundaries" for Aries and "leopard" for Leo.
    _SIGN_RE = re.compile(r"\b(" + "|".join(_SIGN_NAMES) + r")\b", flags=re.I)

    # Keywords that send a question to the tools instead of retrieval.
    _TOOL_KEYWORDS = ("lucky number", "lucky", "today", "date", "time", "now")
    # Whole-word matcher, so "know", "update" or "timeline" no longer trigger a tool.
    _TOOL_RE = re.compile(
        r"\b(?:" + "|".join(map(re.escape, _TOOL_KEYWORDS)) + r")\b", flags=re.I
    )

    # Phrases that suggest a general-knowledge question (answered without retrieval).
    _GENERAL_HINTS = ("capital", "recipe", "how to", "who is", "what is", "define",
                      "population", "country", "city", "explain", "difference between")
    # Whole-word matcher, so e.g. "capacity"/"electricity" no longer look like "city".
    _GENERAL_RE = re.compile(
        r"\b(?:" + "|".join(map(re.escape, _GENERAL_HINTS)) + r")\b", flags=re.I
    )

    def __init__(
        self,
        index_name: str = "langchain-demo-signed-v2",
        cloud: str = "aws",
        region: str = "us-east-1",
    ):
        """Load the corpus, connect/populate Pinecone, and compile the LangGraph app.

        Args:
            index_name: Pinecone index name. Created when missing; documents are
                upserted only when the index holds no vectors yet.
            cloud: Cloud provider for a newly created serverless index.
            region: Region for a newly created serverless index.

        Raises:
            RuntimeError: If PINECONE_API_KEY or GROQ_API_KEY is not set.
            FileNotFoundError: If ./horoscope.txt does not exist.
        """
        load_dotenv()

        # --- Required env ---
        pinecone_key = os.getenv("PINECONE_API_KEY")
        groq_key = os.getenv("GROQ_API_KEY")
        if not pinecone_key:
            raise RuntimeError("Missing PINECONE_API_KEY in environment.")
        if not groq_key:
            raise RuntimeError("Missing GROQ_API_KEY in environment.")

        # --- Load base docs ---
        if not os.path.exists("./horoscope.txt"):
            raise FileNotFoundError("Couldn't find './horoscope.txt'. Make sure the file exists.")
        loader = TextLoader("./horoscope.txt")
        base_docs = loader.load()  # list[Document]

        # --- Tag each BASE document once with its dominant sign (no nulls) ---
        # NOTE(review): if TextLoader yields a single Document for the whole file
        # (presumably it does), every chunk inherits the SAME sign and the strict
        # sign filter below discards content for all other signs. Tagging per chunk
        # would avoid that, but an already-populated index would need rebuilding.
        SIGN_LIST = list(self._SIGN_NAMES)
        sign_re = self._SIGN_RE

        for d in base_docs:
            text = d.page_content or ""
            found = [s.lower() for s in sign_re.findall(text)]
            d.metadata = (d.metadata or {})
            if found:
                dominant = Counter(found).most_common(1)[0][0]
                d.metadata["sign"] = dominant      # e.g., "sagittarius"
            else:
                d.metadata.pop("sign", None)       # ensure no null/None
            d.metadata.setdefault("source", "horoscope.txt")

        # --- Split AFTER tagging so chunks inherit metadata['sign'] ---
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=4)
        docs = text_splitter.split_documents(base_docs)

        # --- Embeddings ---
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        dim = len(self.embeddings.embed_query("ping"))  # expected 768

        # --- Pinecone setup ---
        pc = Pinecone(api_key=pinecone_key)
        self.index_name = index_name

        def _idx_name(x):
            # list_indexes() entries may be objects or dicts depending on client version.
            return x.name if hasattr(x, "name") else (x.get("name") if isinstance(x, dict) else None)

        existing = {_idx_name(i) for i in pc.list_indexes()}
        if self.index_name not in existing:
            pc.create_index(
                name=self.index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(cloud=cloud, region=region),
            )

        idx = pc.Index(self.index_name)
        stats = idx.describe_index_stats()
        namespaces = stats.get("namespaces", {}) or {}
        total = sum(ns.get("vector_count", 0) for ns in namespaces.values()) if namespaces else stats.get("total_vector_count", 0) or 0

        # If index is empty, upsert; otherwise connect
        if total == 0:
            self.vectorstore = PineconeVectorStore.from_documents(
                docs, embedding=self.embeddings, index_name=self.index_name
            )
        else:
            self.vectorstore = PineconeVectorStore(index_name=self.index_name, embedding=self.embeddings)

        # --- LLMs (Groq) ---
        self.chat_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.3, max_tokens=200)
        self.rewriter_llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0.1, max_tokens=96)

        # --- Advanced retrieval (MultiQuery + rerank + compression) ---
        base_retriever = self.vectorstore.as_retriever(search_kwargs={"k": 10})
        mqr = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=self.rewriter_llm)

        reranker_model_name = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-base")
        cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_name)
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=5)

        self.advanced_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=mqr,
        )

        # --- Helpers ---
        self.SIGNS = SIGN_LIST

        def extract_sign_from_text(text: str) -> Optional[str]:
            # Whole-word match; see ChatBot._find_sign.
            return ChatBot._find_sign(text)

        # STRICT: keep only docs whose metadata["sign"] matches the target sign
        def filter_docs_by_sign(docs_list: List[Document], sign: Optional[str]) -> List[Document]:
            if not sign:
                return docs_list
            s = sign.lower()
            keep: List[Document] = []
            for d in docs_list:
                md_sign = ((d.metadata or {}).get("sign") or "").lower()
                if md_sign == s:
                    keep.append(d)
            return keep

        def extract_citations(docs_list: List[Document]) -> List[str]:
            cites = []
            for i, d in enumerate(docs_list, 1):
                md = (getattr(d, "metadata", {}) or {})
                src = md.get("source") or md.get("file") or "horoscope.txt"
                cites.append(f"[{i}] {src}")
            return cites

        def format_context(docs_list: List[Document]) -> str:
            return "\n\n".join(
                getattr(d, "page_content", "") for d in docs_list if getattr(d, "page_content", "")
            )

        def is_general_question(q: str) -> bool:
            """Heuristic: True for general knowledge/how-to questions (not astrology/tools)."""
            return ChatBot._is_general_question(q)

        self.extract_sign_from_text = extract_sign_from_text
        self.filter_docs_by_sign = filter_docs_by_sign
        self.extract_citations = extract_citations
        self.format_context = format_context
        self.is_general_question = is_general_question

        # --- Tools ---
        @tool("lucky_number")
        def lucky_number(name_or_sign: str) -> str:
            """Deterministic 'lucky number' (1-9) from a name or zodiac sign."""
            h = int(hashlib.md5(name_or_sign.strip().lower().encode("utf-8")).hexdigest(), 16)
            num = (h % 9) + 1
            return f"Lucky number for '{name_or_sign}': {num}"

        @tool("now")
        def now(_: str = "") -> str:
            """Current date/time (ISO format)."""
            return datetime.datetime.now().isoformat(timespec="seconds")

        self.lucky_number = lucky_number
        self.now = now

        # --- LangGraph: route → tools/GENERAL/RAG → grade → fallback/generate ---
        class RAGState(TypedDict, total=False):
            session_id: str
            question: str
            history: List
            route: Literal["TOOLS", "GENERAL", "RAG"]
            target_sign: Optional[str]
            docs: List[Document]
            citations: List[str]
            tool_result: Optional[str]
            grounded: bool
            answer: str

        # --- Prompts ---
        rag_prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a concise fortune teller.\n"
             "- Only answer about the user's sign: {target_sign} (if provided).\n"
             "- Use ONLY the provided context and tool output.\n"
             "- Answer the user's single question in 1–2 sentences.\n"
             "- Do NOT invent new questions or headings.\n"
             "- If the context is empty or about a different sign, reply exactly: I don't know."),
            MessagesPlaceholder("history"),
            ("human",
             "Context:\n{context}\n\nTool:\n{tool_result}\n\nUser question: {question}\n"
             "Citations (optional): {citations}")
        ])
        generator = rag_prompt | self.chat_llm | StrOutputParser()

        general_prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a helpful assistant. Answer the user clearly and concisely in 2–3 sentences."),
            MessagesPlaceholder("history"),
            ("human", "{question}")
        ])
        general_generator = general_prompt | self.chat_llm | StrOutputParser()

        # --- Nodes ---
        def route_node(state: RAGState) -> RAGState:
            q = (state.get("question") or "")
            # NOTE(review): "today"/"now" also occur in horoscope questions
            # ("What does today hold for Leo?"), which therefore go to the tools.
            if ChatBot._is_tool_question(q):
                route = "TOOLS"
            elif self.is_general_question(q):
                route = "GENERAL"
            else:
                route = "RAG"
            return {**state, "route": route, "target_sign": self.extract_sign_from_text(q)}

        def tools_node(state: RAGState) -> RAGState:
            q = state["question"]
            # Whole-word check, consistent with the router.
            if re.search(r"\blucky\b", q, flags=re.I):
                target = self.extract_sign_from_text(q) or q.strip()
                result = self.lucky_number.invoke(target.title())
            else:
                result = self.now.invoke("")
            return {**state, "tool_result": result, "answer": result, "docs": [], "citations": [], "grounded": True}

        def general_node(state: RAGState) -> RAGState:
            ans = general_generator.invoke({
                "history": state.get("history", []),
                "question": state["question"]
            })
            return {**state, "answer": ans, "docs": [], "citations": [], "grounded": False}

        def retrieve_node(state: RAGState) -> RAGState:
            q = state["question"]
            sign = state.get("target_sign")
            biased_query = f"[{sign}] {q}" if sign else q  # bias retrieval toward the sign
            docs_list = self.advanced_retriever.invoke(biased_query)
            if sign:
                docs_list = self.filter_docs_by_sign(docs_list, sign)
            return {**state, "docs": docs_list, "citations": self.extract_citations(docs_list)}

        def grade_node(state: RAGState) -> RAGState:
            # "Grounded" simply means at least one document survived retrieval/filtering.
            grounded = len(state.get("docs") or []) > 0
            return {**state, "grounded": grounded}

        def fallback_node(state: RAGState) -> RAGState:
            # Non-general, horoscope-style query with no matching context
            sign = state.get("target_sign")
            if sign:
                msg = f"I don't have content for {sign.title()} yet. Please add it to the corpus."
            else:
                msg = "I don't know yet. Please add relevant content to the corpus."
            return {**state, "answer": msg, "grounded": False}

        def generate_node(state: RAGState) -> RAGState:
            if state.get("tool_result"):
                return state
            if not state.get("docs"):
                return fallback_node(state)

            context = self.format_context(state.get("docs", []))
            tool_text = state.get("tool_result", "")
            cites = " ".join(state.get("citations", []))
            answer = generator.invoke({
                "history": state.get("history", []),
                "context": context,
                "tool_result": tool_text,
                "question": state["question"],
                "citations": cites,
                "target_sign": (state.get("target_sign") or ""),
            })
            # Strip any "Question: ..." lines the model may echo back.
            answer = re.sub(r'(?mi)^\s*Question:.*$', '', answer).strip()
            return {**state, "answer": answer}

        # --- Graph wiring ---
        graph = StateGraph(RAGState)
        graph.add_node("route", route_node)
        graph.add_node("tools", tools_node)
        graph.add_node("general", general_node)
        graph.add_node("retrieve", retrieve_node)
        graph.add_node("grade", grade_node)
        graph.add_node("fallback", fallback_node)
        graph.add_node("generate", generate_node)

        def _router(s: RAGState) -> str:
            return "tools" if s["route"] == "TOOLS" else ("general" if s["route"] == "GENERAL" else "retrieve")

        graph.add_edge(START, "route")
        graph.add_conditional_edges("route", _router)
        graph.add_edge("tools", END)
        graph.add_edge("general", END)
        graph.add_edge("retrieve", "grade")
        graph.add_conditional_edges("grade", lambda s: "generate" if s["grounded"] else "fallback")
        graph.add_edge("generate", END)
        graph.add_edge("fallback", END)

        self.graph_app = graph.compile()

        # --- Memory wrapper (ephemeral, in-process) ---
        self._session_store = {}
        def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
            if session_id not in self._session_store:
                self._session_store[session_id] = InMemoryChatMessageHistory()
            return self._session_store[session_id]

        answer_only = self.graph_app | RunnableLambda(lambda state: state["answer"])
        self.graph_with_memory = RunnableWithMessageHistory(
            answer_only,
            get_session_history,
            input_messages_key="question",
            history_messages_key="history",
        )

        # Back-compat for Streamlit
        class _Invoker:
            def __init__(self, outer):
                self._outer = outer
            def invoke(self, question: str, session_id: str = "default"):
                cfg = {"configurable": {"session_id": session_id}}
                return self._outer.graph_with_memory.invoke({"session_id": session_id, "question": question}, config=cfg)
        self.rag_chain = _Invoker(self)

    # Convenience method
    def answer_with_graph(self, question: str, session_id: str = "default") -> str:
        """Answer ``question`` through the memory-wrapped graph for ``session_id``."""
        cfg = {"configurable": {"session_id": session_id}}
        return self.graph_with_memory.invoke({"session_id": session_id, "question": question}, config=cfg)

    # --- Session-history helpers (in-RAM store, optional JSON persistence) ---
    def list_sessions(self) -> List[str]:
        """Return the session IDs currently held in the in-memory store."""
        return list(self._session_store)

    def get_history(self, session_id: str) -> List[dict]:
        """Return the session's messages as JSON-serialisable dicts ([] if the session is unknown)."""
        history = self._session_store.get(session_id)
        if history is None:
            return []
        from langchain_core.messages import messages_to_dict

        return messages_to_dict(history.messages)

    def clear_history(self, session_id: str) -> None:
        """Wipe the in-memory history for ``session_id``. Unknown sessions are ignored."""
        history = self._session_store.get(session_id)
        if history is not None:
            history.clear()

    def save_history(self, session_id: str, path: str) -> None:
        """Write the session's history (see ``get_history``) to ``path`` as UTF-8 JSON."""
        import json

        with open(path, "w", encoding="utf-8") as fh:
            json.dump(self.get_history(session_id), fh, ensure_ascii=False, indent=2)

    def load_history(self, session_id: str, path: str) -> None:
        """Replace the session's in-memory history with the messages stored in ``path``.

        Raises:
            OSError: If ``path`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        import json
        from langchain_core.messages import messages_from_dict

        with open(path, encoding="utf-8") as fh:
            messages = messages_from_dict(json.load(fh))
        history = self._session_store.get(session_id)
        if history is None:
            history = InMemoryChatMessageHistory()
            self._session_store[session_id] = history
        history.clear()
        history.add_messages(messages)

    # --- Text-matching helpers (stateless) ---
    @staticmethod
    def _find_sign(text: Optional[str]) -> Optional[str]:
        """Return the zodiac sign mentioned as a whole word in ``text``, else None.

        When several signs are mentioned, the one earliest in ``_SIGN_NAMES`` wins.
        """
        mentioned = {m.lower() for m in ChatBot._SIGN_RE.findall(text or "")}
        return next((s for s in ChatBot._SIGN_NAMES if s in mentioned), None)

    @staticmethod
    def _is_tool_question(text: Optional[str]) -> bool:
        """True when ``text`` contains one of ``_TOOL_KEYWORDS`` as a whole word/phrase."""
        return ChatBot._TOOL_RE.search(text or "") is not None

    @staticmethod
    def _is_general_question(text: Optional[str]) -> bool:
        """Heuristic: True for general knowledge/how-to questions (not astrology/tools)."""
        if ChatBot._find_sign(text) is not None:
            return False
        if ChatBot._is_tool_question(text):
            return False
        return ChatBot._GENERAL_RE.search(text or "") is not None


if __name__ == "__main__":
    bot = ChatBot()
    # Used when stdin is closed (EOFError) or the user just presses Enter.
    default_question = "What can Sagittarius expect this week?"
    try:
        q = input("Ask me anything: ").strip()
    except EOFError:
        q = ""
    if not q:
        # A blank question would still go through routing, retrieval and an LLM call.
        q = default_question
    ans = bot.answer_with_graph(q, session_id="cli-user")
    print(ans)


This year, 2025, will bring opportunities for growth and transformation, especially in your relationships and emotional experiences. However, be mindful of Mercury retrograde periods that might disrupt your plans and communication.


In [8]:
# Chat-history demo: two turns in one session, then persist -> wipe -> restore.
from main import ChatBot

bot = ChatBot()
sid = "demo-1"
history_file = "demo-1.json"

print(bot.list_sessions())                # [] before any turn
for turn in ("Hi!", "Remember this."):
    bot.answer_with_graph(turn, sid)

print(bot.get_history(sid))               # serialized messages of the session

bot.save_history(sid, history_file)       # persist to disk
bot.clear_history(sid)                    # wipe the in-RAM copy
bot.load_history(sid, history_file)       # restore from disk
print(bot.get_history(sid))

[]
[{'type': 'human', 'data': {'content': 'Hi!', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': None, 'example': False}}, {'type': 'ai', 'data': {'content': 'You are a Sagittarius. This year, your journey of discovering new places and meeting people who can offer a fresh perspective on the world and yourself will be highlighted, making it a perfect time to seek deeper connections and expand your horizons.', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': None, 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None}}, {'type': 'human', 'data': {'content': 'Remember this.', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': None, 'example': False}}, {'type': 'ai', 'data': {'content': 'This year, your journey of discovering new places and meeting people who can offer a fresh perspective on the world and yourself will be highlighted, making it a perf

In [None]:
# The chat memory is kept in RAM, locally, inside your Python process (that _session_store dict).

# If you restart Python/Streamlit or the process crashes, that in-RAM memory is gone—unless you save it to a file (save_history(...)).

# save_history(...) writes a JSON file on disk (e.g., demo-1.json). load_history(...) pulls it back into RAM.

# In your Streamlit app:

#     Reset session = new session_id ⇒ starts a fresh memory thread.

#     The "Clear chat" button only clears the UI message list; to truly wipe the memory, call clear_history(session_id).

# Pinecone stores your vectors in the cloud (persistent).

# Groq runs the LLM remotely (stateless).