In [1]:
# Run once if needed
!pip install -U langchain langchain-community langgraph langchain-openai faiss-cpu tiktoken




In [2]:
import os
import glob
from typing import List, Literal, Annotated, TypedDict, Sequence

from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
)
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import tool

from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver





In [2]:
import os, glob
from typing import List, Literal, Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import tool

from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_huggingface import HuggingFaceEndpointEmbeddings

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver





In [None]:
# Hugging Face-hosted models.
# NOTE(review): the original note here only said "hf api key". No token is passed
# in this cell, so credentials presumably come from the environment
# (e.g. HUGGINGFACEHUB_API_TOKEN) — confirm before running.
base_llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    task="text-generation",
    max_new_tokens=600,
    temperature=0.1,
)

# Chat-style wrapper around the endpoint; every graph node below calls this object.
llm = ChatHuggingFace(llm=base_llm)

# Embedding model used both to build the FAISS index and to embed queries against it.
embeddings = HuggingFaceEndpointEmbeddings(
    repo_id="sentence-transformers/all-mpnet-base-v2",
    task="feature-extraction",
)


In [6]:
# Corpus and index locations. The defaults are the original machine-specific
# paths; set LMKR_DATA_DIR / LMKR_FAISS_DIR in the environment to run the
# notebook on another machine without editing this cell.
DATA_DIR = os.environ.get("LMKR_DATA_DIR", r"C:\Users\PMLS\Downloads\langgraph\lmkr_data")
FAISS_DIR = os.environ.get("LMKR_FAISS_DIR", r"C:\Users\PMLS\Downloads\langgraph\lmkr_faiss")

def load_docs(path):
    """Read every ``*.txt`` file directly under ``path`` and split it into chunks.

    Args:
        path: Directory to scan (non-recursively) for ``.txt`` files.

    Returns:
        A list of ``Document`` objects, one per chunk, each carrying
        ``metadata={"source": <file path>, "chunk": <index within that file>}``.
        The list is empty when the directory contains no ``.txt`` files.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    docs = []

    # Sorted so the chunk order (and therefore the index built from it) does not
    # depend on the arbitrary order in which the filesystem lists files.
    for f in sorted(glob.glob(os.path.join(path, "*.txt"))):
        # The context manager closes the handle promptly. Undecodable bytes are
        # dropped (errors="ignore") rather than aborting the whole load.
        with open(f, encoding="utf-8", errors="ignore") as fh:
            text = fh.read()
        docs.extend(
            Document(page_content=c, metadata={"source": f, "chunk": i})
            for i, c in enumerate(splitter.split_text(text))
        )
    return docs

# Chunk the corpus eagerly; `docs` is only consumed later if the FAISS index has to be built.
docs = load_docs(DATA_DIR)


In [7]:
# Reuse a previously saved index when one is present; otherwise embed the chunks
# and persist the result. The check targets the index file itself rather than
# the directory, so an empty or partially written folder triggers a rebuild
# instead of a failed load.
# NOTE: a saved index is never refreshed here; delete FAISS_DIR after changing
# the source documents or the embedding model.
if os.path.exists(os.path.join(FAISS_DIR, "index.faiss")):
    # SECURITY: load_local unpickles the stored docstore, which can execute
    # arbitrary code. Only point FAISS_DIR at an index this notebook created.
    vs = FAISS.load_local(
        FAISS_DIR,
        embeddings,
        allow_dangerous_deserialization=True,
    )
else:
    if not docs:
        # Fail with a readable message instead of an error deep inside FAISS.
        raise ValueError(
            "No text chunks were loaded, so the FAISS index cannot be built; "
            "check that the data directory contains .txt files."
        )
    vs = FAISS.from_documents(docs, embeddings)
    vs.save_local(FAISS_DIR)

# Each query returns the 8 most similar chunks.
retriever = vs.as_retriever(search_kwargs={"k": 8})


In [8]:
def format_docs(docs: List[Document]) -> str:
    """Concatenate the text of ``docs`` into one string, with a blank line between entries."""
    texts = [doc.page_content for doc in docs]
    separator = "\n\n"
    return separator.join(texts)

@tool
def lmkr_retriever(query: str) -> str:
    """Retrieve LMKR static documents."""
    # The docstring above is left verbatim: `@tool` presumably surfaces it to the
    # model as the tool description, which would make it runtime text.
    # Look up the chunks closest to `query` and hand them back as one text blob.
    return format_docs(retriever.invoke(query))


In [9]:
class AgentState(TypedDict):
    """State passed between graph nodes: the conversation's message list."""

    # `add_messages` is attached to this key as its reducer, so the message lists
    # that nodes return are combined with the existing history by that function.
    # NOTE(review): the exact merge semantics come from LangGraph — confirm there.
    messages: Annotated[Sequence[BaseMessage], add_messages]


In [10]:
def agent_node(state: AgentState):
    """Ask the tool-enabled chat model for its next message, given the history so far.

    Returns a partial state update holding only the model's reply; the graph's
    routing then inspects that reply to decide whether retrieval is needed.
    """
    print("---CALL AGENT---")
    tool_aware_llm = llm.bind_tools([lmkr_retriever])
    history = state["messages"]
    return {"messages": [tool_aware_llm.invoke(history)]}


In [11]:
def grade_documents(state: AgentState) -> Literal["generate", "rewrite"]:
    """Route after retrieval: decide whether the fetched text can answer the question.

    The chat model is asked for a yes/no relevance verdict on the retrieved text
    (the content of the last message) with respect to the current question.

    Args:
        state: Graph state; ``state["messages"]`` must be non-empty and end with
            the retrieval result.

    Returns:
        ``"generate"`` when the verdict starts with "yes" (case-insensitive,
        ignoring surrounding whitespace); otherwise ``"rewrite"``.
    """
    messages = state["messages"]
    # The graph is compiled with a checkpointer, so earlier turns stay in the
    # history and messages[0] is the first question ever asked on the thread.
    # Grade against the most recent human message instead (after a rewrite this
    # is the rewritten question); fall back to messages[0] if none is found.
    question = next(
        (m.content for m in reversed(messages) if isinstance(m, HumanMessage)),
        messages[0].content,
    )
    docs = messages[-1].content

    prompt = PromptTemplate(
        template=(
            "Question:\n{question}\n\n"
            "Documents:\n{context}\n\n"
            "Are the documents relevant? Reply yes or no."
        ),
        input_variables=["question", "context"],
    )

    decision = (prompt | llm | StrOutputParser()).invoke(
        {"question": question, "context": docs}
    )

    # Chat models often pad the verdict with a leading space or newline; without
    # strip() a reply such as " Yes" would be misread as "no".
    return "generate" if decision.strip().lower().startswith("yes") else "rewrite"


In [12]:
def rewrite_node(state: AgentState):
    """Ask the chat model to rephrase the current question for better retrieval.

    Args:
        state: Graph state; ``state["messages"]`` must be non-empty.

    Returns:
        A partial state update containing the rewritten question as a new
        ``HumanMessage``.
    """
    messages = state["messages"]
    # With a checkpointer the history spans turns, so messages[0] may be an old
    # question. Use the most recent human message; fall back to messages[0].
    question = next(
        (m.content for m in reversed(messages) if isinstance(m, HumanMessage)),
        messages[0].content,
    )

    # The question is supplied as a template *variable*, never as the template
    # itself: user text containing "{" or "}" would otherwise be parsed as
    # placeholders and make the prompt fail for want of those variables.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "Rewrite the question to improve retrieval."),
        ("human", "{question}"),
    ])

    rewritten = (prompt | llm | StrOutputParser()).invoke({"question": question})
    return {"messages": [HumanMessage(content=rewritten)]}


In [13]:
def generate_node(state: AgentState):
    """Answer the current question using only the retrieved text.

    The retrieved text is taken from the content of the last message in the
    state, which is the retrieval result when this node runs.

    Args:
        state: Graph state; ``state["messages"]`` must be non-empty.

    Returns:
        A partial state update containing the model's reply as a message
        object, so its ``.content`` is the answer text.
    """
    messages = state["messages"]
    # With a checkpointer the history spans turns, so messages[0] may be an old
    # question. Use the most recent human message; fall back to messages[0].
    question = next(
        (m.content for m in reversed(messages) if isinstance(m, HumanMessage)),
        messages[0].content,
    )
    docs = messages[-1].content

    prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer using the provided documents only."),
        ("human", "Context:\n{context}\n\nQuestion:\n{question}"),
    ])

    # Deliberately no StrOutputParser: a bare string placed in "messages" is
    # coerced by the add_messages reducer into a *human* message, which would
    # store the model's answer under the wrong role in the saved history.
    answer = (prompt | llm).invoke({"context": docs, "question": question})

    return {"messages": [answer]}


In [14]:
# Assemble the agentic-RAG graph over AgentState.
workflow = StateGraph(AgentState)

# Nodes: the tool-calling agent, the retriever tool executor, the question
# rewriter, and the final answer generator.
workflow.add_node("agent", agent_node)
workflow.add_node("retrieve", ToolNode([lmkr_retriever]))
workflow.add_node("rewrite", rewrite_node)
workflow.add_node("generate", generate_node)

# Every run starts at the agent.
workflow.add_edge(START, "agent")

# After the agent: tools_condition picks "tools" or END; "tools" is mapped to
# the "retrieve" node here, END finishes the run.
workflow.add_conditional_edges(
    "agent",
    tools_condition,
    {
        "tools": "retrieve",
        END: END,
    },
)

# After retrieval: grade_documents returns the name of the next node
# ("generate" or "rewrite"); no explicit path map is given.
workflow.add_conditional_edges(
    "retrieve",
    grade_documents,
)

# A rewritten question goes back to the agent; a generated answer ends the run.
# NOTE(review): agent -> retrieve -> rewrite -> agent is a cycle and nothing in
# this cell caps the number of rewrite attempts.
workflow.add_edge("rewrite", "agent")
workflow.add_edge("generate", END)

# MemorySaver keeps per-thread state in memory between invocations.
app = workflow.compile(checkpointer=MemorySaver())


In [15]:
def ask(q, thread_id="agentic_rag_hf"):
    """Run one question through the graph and return the text of the final message.

    Args:
        q: The user's question.
        thread_id: Checkpointer thread to run on. Calls sharing a thread id share
            the saved message history; pass a different id for an independent
            conversation. Defaults to the id this notebook has always used.

    Returns:
        The ``content`` of the last message in the resulting state.
    """
    result = app.invoke(
        {"messages": [HumanMessage(content=q)]},
        # Checkpointer settings are documented to live under "configurable";
        # a top-level "thread_id" relies on implicit key hoisting.
        config={"configurable": {"thread_id": thread_id}},
    )
    return result["messages"][-1].content

# Smoke test: send one question through the graph and display the returned text.
# NOTE(review): the saved output of this cell is a BadRequestError raised after
# the agent node had started, so this call did not complete when last run.
ask("What products does LMKR offer?")


---CALL AGENT---


BadRequestError: (Request ID: Root=1-69405a4c-1ba46797462e0ecc07b5a55b;6ff28be3-5ff8-48d3-9693-fa4304963a89)

Bad request:

In [16]:
from IPython.display import HTML, display

# Mermaid source text describing the compiled graph.
mermaid_src = app.get_graph().draw_mermaid()

# Render the diagram client-side: the HTML below loads require.js, uses it to
# fetch Mermaid from a CDN, injects the source into the container div and asks
# Mermaid to draw it. Doubled braces are literal braces in the f-string;
# {mermaid_src} is the only interpolated value.
# NOTE(review): mermaid_src is inserted unescaped into a JavaScript template
# literal and then into innerHTML, so a backtick, "${" or markup in the source
# would alter the script or page — confirm the generated text is safe.
# NOTE(review): needs network access to both CDNs when the cell is displayed.
display(HTML(f"""
<div id="mermaid-container"></div>

<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.1.10/require.min.js"></script>
<script>
require.config({{
    paths: {{
        mermaid: "https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.min"
    }}
}});

require(["mermaid"], function(mermaidLib) {{
    mermaidLib.initialize({{
        startOnLoad: false,
        theme: "default",
        flowchart: {{ curve: "linear" }}
    }});

    document.getElementById("mermaid-container").innerHTML =
        `<pre class="mermaid">{mermaid_src}</pre>`;

    mermaidLib.run();
}});
</script>
"""))
