In [None]:
from langchain_chroma import Chroma
from langchain_core.tools import tool
from langchain_core.documents import Document
from langchain_community.utilities import ArxivAPIWrapper, WikipediaAPIWrapper
from langchain_community.tools import ArxivQueryRun, WikipediaQueryRun
from langchain_tavily import TavilySearch
from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import InMemorySaver
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph.message import add_messages
from typing import TypedDict, Annotated, Literal, Optional, Dict
from pydantic import BaseModel, Field
from langchain_groq import ChatGroq
import os
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_ollama import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate



In [2]:
# Pull configuration from the local .env file before any model is constructed.
load_dotenv()

# Chat model served by Groq; its name and API key are read from the
# "model" and "api_key" environment variables.
model = ChatGroq(model=os.getenv("model"), api_key=os.getenv("api_key"))

# Ollama embedding model, later handed to the vector store to embed the chunks.
embedding_model = OllamaEmbeddings(model="llama3.2")

In [None]:
# Load the source PDF and report how many documents the loader produced.
pdf_path = "Attention.pdf"
loader = PyPDFLoader(file_path=pdf_path)
docs = loader.load()
print(len(docs))

11


In [20]:
# Break the loaded documents into small, overlapping chunks for embedding.
splitter_config = {"chunk_size": 200, "chunk_overlap": 80}
text_splitter = RecursiveCharacterTextSplitter(**splitter_config)
chunks = text_splitter.split_documents(docs)

# Report the number of chunks produced.
print(len(chunks))

234


In [21]:
# Embed every chunk with the Ollama embedding model and index the results
# in a Chroma collection named "my_collection".
vector_store = Chroma.from_documents(
    documents=chunks,
    embedding=embedding_model,
    collection_name="my_collection",
)

In [22]:
# Retriever view over the vector store using MMR search (k=3, lambda_mult=0.7).
mmr_search_kwargs = {'k': 3, 'lambda_mult': 0.7}
retrievers = vector_store.as_retriever(
    search_type="mmr",
    search_kwargs=mmr_search_kwargs,
)


In [23]:
# Output schema for structured LLM replies: a single `output` field holding a
# list of strings.
# NOTE(review): documented with comments rather than a class docstring on
# purpose -- a docstring may be emitted as the schema description and would then
# change the text returned by get_format_instructions(); confirm before adding one.
class Format(BaseModel):
    output :list[str]

# Parser built from `Format`. The decomposition node embeds its format
# instructions in the prompt and pipes the model reply through it.
op_format = PydanticOutputParser(pydantic_object=Format)

In [29]:
def _append_items(existing: Optional[list], new: Optional[list]) -> list:
    """Reducer that concatenates plain lists without converting their items.

    `add_messages` is meant for chat messages and coerces bare strings into
    message objects, which would violate the `list[str]` annotations below.
    A single string update is treated as a one-element list so it is not
    splatted into characters.
    """
    if isinstance(new, str):
        new = [new]
    return [*(existing or []), *(new or [])]


class RAG_state(BaseModel):
    """Shared state carried between the nodes of the RAG graph.

    Only `latest_question` is required when the graph is invoked; every other
    field has a neutral default so a run can start from the question alone and
    let the nodes fill in the rest.
    """

    # Chat messages holding the user's question; merged with `add_messages`.
    latest_question : Annotated[list[HumanMessage], add_messages]
    # Plain-string sub-questions, accumulated across updates.
    sub_questions : Annotated[list[str], _append_items] = Field(default_factory=list)
    # Plain-string answers, accumulated across updates.
    answers : Annotated[list[str], _append_items] = Field(default_factory=list)
    # Final answer text; empty until a node sets it.
    final_answer : str = ""
    # Retry counter -- presumably incremented by a retry/evaluation node (not visible here).
    retry_count : int = 0
    # Score constrained to the closed interval [0, 1].
    evaluation : float = Field(default=0.0, ge=0, le=1)
    # Clarification text; empty until a node sets it.
    clarification : str = ""


In [30]:
# # I need sqlite db to store the checkpoints to make my agent remember the past topics
# conn = sqlite3.connect(database='RAG.db', check_same_thread=False)
# # this creates the checkpointer that saves checkpoints in the database and allows persistence
# checkpointer = SqliteSaver(conn=conn)

In [None]:
# all node's task
def query_decomposition(state : RAG_state):
    """Split the user's latest question into a few retrieval-friendly sub-questions.

    The question text is inserted into a decomposition prompt, sent to the chat
    model, and the reply is parsed by `op_format` into a list of strings.

    Args:
        state: Current graph state. `latest_question` must contain at least
            one message.

    Returns:
        A partial state update `{'sub_questions': [...]}` with the generated
        sub-questions as a flat list of strings.

    Raises:
        ValueError: If `latest_question` is empty.
    """
    # The state schema is a Pydantic model, which cannot be subscripted; use
    # attribute access, but still accept a plain dict for dict-style callers.
    messages = state["latest_question"] if isinstance(state, dict) else state.latest_question
    if not messages:
        raise ValueError("latest_question is empty; there is no question to decompose")

    # Only the newest message's text goes into the prompt. Interpolating the
    # whole list would embed the repr of the message objects instead.
    latest = messages[-1]
    question_text = latest.content if hasattr(latest, "content") else str(latest)

    # Ask for several differently-angled questions so retrieval can pull
    # relevant content from more than one direction.
    template = PromptTemplate(template="""You are a query decomposition agent.
    Your task is to transform a single high-level user question into 2-3 concise, non-overlapping sub-questions optimized for semantic search and retrieval.
    Rules:
    - Each sub-question must target a distinct angle of the original question.
    - Sub-questions must be factual, concrete, and searchable.
    - Avoid rephrasing the same question multiple times.
    - Do NOT answer the questions.
    - Do NOT add explanations.
    - Keep each sub-question under 15 words only.
    - Prefer "what / how / why" formulations when useful.
    Question: 
    {question}
    Output format:
    {fm}""", input_variables=['question'],partial_variables={'fm' : op_format.get_format_instructions()})
    chain = template | model | op_format
    output = chain.invoke({"question" : question_text}).output

    # The key must match the `sub_questions` state field, and `output` is
    # already a list of strings, so it is returned without extra nesting.
    return {'sub_questions' : output}

In [None]:
# Create the graph builder over the shared RAG_state schema.
graph = StateGraph(RAG_state)

# Register the decomposition function as the node named "query_decomposition".
graph.add_node("query_decomposition", query_decomposition)
# NOTE(review): `retriever` is not defined in any cell visible here (only the
# `retrievers` object is). Confirm that a `retriever` node function exists
# before this cell runs, otherwise this line raises NameError.
graph.add_node("Retriever", retriever)