In [None]:
import operator
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
load_dotenv()


In [None]:
class AgentState(TypedDict):
    """Graph state passed between the LangGraph nodes.

    query:  the user's question, supplied when the graph is invoked.
    answer: the generated answer, filled in by the retrieval node.
    """

    query: str
    answer: str

In [None]:
# Corpus to index -- point SAMPLE_PATH at your own .txt file.
SAMPLE_PATH = "sample.txt"
# Chunk sizes are in characters (the splitter's default length function is len).
CHUNK_SIZE = 100
CHUNK_OVERLAP = 10

# Read explicitly as UTF-8 so loading does not depend on the OS default encoding
# (e.g. cp1252 on Windows would fail on non-ASCII UTF-8 text).
loader = TextLoader(SAMPLE_PATH, encoding="utf-8")
docs = loader.load()

splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
chunks = splitter.split_documents(docs)
# FAISS.from_documents cannot build an index from an empty document list;
# fail here with an actionable message instead of an opaque error later.
if not chunks:
    raise ValueError(f"No text chunks were produced from {SAMPLE_PATH!r}; is the file empty?")

embedding_model = OpenAIEmbeddings()
vector_store = FAISS.from_documents(documents=chunks, embedding=embedding_model)


In [None]:
# Retrieve with Maximal Marginal Relevance (MMR), which re-ranks candidate
# chunks to balance relevance to the query against diversity among results.
#   k           -- number of chunks returned to the chain
#   lambda_mult -- 0 = maximum diversity, 1 = minimum diversity (pure relevance)
mmr_search_kwargs = {"k": 3, "lambda_mult": 0.5}
retriever = vector_store.as_retriever(
    search_type="mmr",
    search_kwargs=mmr_search_kwargs,
)

In [None]:
# Chat model that writes the final answer; temperature=0 minimizes sampling randomness.
model = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
)

# Turns the model's chat-message output into a plain string.
parser = StrOutputParser()

In [None]:
# RAG prompt with two slots: {context} receives the retrieved text and
# {question} receives the user's query. PromptTemplate is imported in the
# top import cell (from langchain_core.prompts).
prompt = PromptTemplate.from_template(
    "Given the following context:\n\n{context}\n\nAnswer the question:\n{question}"
)


In [None]:
def format_docs(docs):
    """Join the page_content of retrieved documents into one context string.

    Without this, the list of Document objects would be rendered into the
    prompt via str(), leaking reprs and metadata into the model input.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Input: a mapping with a "query" key. Output: the model's answer as a string.
# itemgetter(...) | retriever keeps retrieval inside the LCEL sequence, so it
# receives the same run config/callbacks as the rest of the chain.
chain = (
    RunnableMap({
        "context": operator.itemgetter("query") | retriever | format_docs,
        "question": operator.itemgetter("query"),
    })
    | prompt
    | model
    | parser
)


In [None]:
# Node 1: just pass query
def input_node(state: AgentState) -> AgentState:
    return {"query": state["query"]}

# Node 2: retrieve + answer
def retrieval_node(state: AgentState) -> AgentState:
    response = chain.invoke(state)
    return {"answer": response}


graph = StateGraph(AgentState)
graph.add_node("input", input_node)
graph.add_node("retrieval", retrieval_node)


graph.set_entry_point("input")
graph.add_edge("input", "retrieval")
graph.set_finish_point("retrieval")

app = graph.compile()

In [None]:
# Run the graph once; "answer" starts empty and is filled by the retrieval node.
initial_state = {"query": "What is LangGraph?", "answer": ""}
result = app.invoke(initial_state)

# Full final state, then just the generated answer.
print(result)
print(result["answer"])