# 6. Integrating query rewriting and reranking into the RAG flow

**Query rewriting** and **reranking** are combined into a complete flow with **LangGraph**. The pipeline is modeled as a state graph with explicit stages (rewriting → retrieval → reranking → generation), and each node transforms a shared state (**RAGState**).

⚠️ **Prerequisite**: Run notebooks `03_rag_base.ipynb` and `05_reranking.ipynb` first (or at least 03) so that the `faiss_index` folder exists. This notebook loads the index and sets up the retriever, the language model, the rewriting prompt and the cross-encoder.

```python
import os
from typing import List, TypedDict

from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langgraph.graph import END, StateGraph
from sentence_transformers import CrossEncoder

# Load GOOGLE_API_KEY from a .env file (if present) into the environment.
load_dotenv()
if not os.getenv("GOOGLE_API_KEY"):
    raise RuntimeError("GOOGLE_API_KEY is not set; add it to your .env file or environment.")

# Must be the same embedding model that was used to build the FAISS index.
embeddings_model = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001")
language_model = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

# The index is deserialized with pickle: only load indexes you created yourself.
vectorstore = FAISS.load_local("./faiss_index", embeddings_model, allow_dangerous_deserialization=True)
# Each subquery retrieves its 10 most similar chunks.
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 10})

system_rewrite_prompt = """You are a helpful assistant that generates multiple search queries based on a single input query.

Perform query expansion. If there are multiple common ways of phrasing a user query
or common synonyms for key words in the query, make sure to return multiple versions
of the query with the different phrasings.

If there are acronyms or words you are not familiar with, do not try to rephrase them.

Return exactly 3 different rewritten versions of the query.
Do not include explanations, commentary, or any other text besides the numbered rewritten queries."""

# Cross-encoder that scores each (query, document) pair jointly, used for reranking.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
print("Models and retriever ready.")
```

## Shared state (RAGState)

Data contract between nodes: the original query, the rewritten queries, the documents (as text), their sources and the final answer.
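`TypedDict` only declares the expected keys for type checkers; Python does not validate it at runtime, so a node that returned a misspelled key (for example `rewritten_queries` instead of `rewrited_queries`) would not be flagged by Python itself. A minimal sketch, using two hypothetical helpers (`unknown_keys` and `sources_aligned`, not part of the notebook), that checks an update against this contract, including the implicit rule that `documents` and `sources` are aligned by index:

```python
from typing import List, TypedDict, get_type_hints


class RAGState(TypedDict):
    query: str
    rewrited_queries: List[str]
    documents: List[str]
    sources: List[str]
    answer: str


def unknown_keys(update: dict) -> set:
    """Hypothetical helper: keys in a node's update that RAGState does not declare."""
    return set(update) - set(get_type_hints(RAGState))


def sources_aligned(state: dict) -> bool:
    """Hypothetical helper: sources[i] must describe documents[i], so lengths must match."""
    return len(state.get("documents", [])) == len(state.get("sources", []))


print(unknown_keys({"answer": "...", "rewritten_queries": []}))  # {'rewritten_queries'}
print(sources_aligned({"documents": ["a", "b"], "sources": ["x.md"]}))  # False
```

Checks like these fit best in a test or as a temporary assertion while developing a node; they are not part of the pipeline itself.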

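```python
class RAGState(TypedDict):
    """Shared state of the RAG flow in LangGraph."""
    query: str                   # original user question
    rewrited_queries: List[str]  # reformulations produced by the rewriting node
    documents: List[str]         # retrieved (later reranked) chunk texts
    sources: List[str]           # source of each document, aligned by index
    answer: str                  # final generated answer
```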

## Node: Query rewriting

Zero-shot: the original query is turned into several reformulations. The model output is split into lines and the numbering is stripped to obtain the list of subqueries.
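The parsing step can be tested without calling the model. A minimal sketch, assuming the model prefixes each line with a marker such as `1.`, `1)`, `-` or `*`; `parse_rewrites` is a hypothetical name. The regular expression removes a marker only when one is present, so unnumbered lines and two-digit numbers such as `10.` are handled as well:

```python
import re
from typing import List

# Leading list markers: "1.", "2)", "10." or "-" / "*" bullets, plus following spaces.
_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*])\s*")


def parse_rewrites(text: str) -> List[str]:
    """Split the model output into lines, strip list markers and drop empty lines."""
    queries = []
    for line in text.splitlines():
        cleaned = _MARKER.sub("", line).strip()
        if cleaned:
            queries.append(cleaned)
    return queries


sample = "1. What is a REST API?\n2) REST principles\n\n10. RESTful design constraints"
print(parse_rewrites(sample))
# ['What is a REST API?', 'REST principles', 'RESTful design constraints']
```

A number that is not followed by `.` or `)` (for example `2024 API trends`) is left untouched.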

```python
import re

# Leading list markers ("1.", "2)", "-", "*") the model may put before each query.
LIST_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*])\s*")

def rewrite_query(state: RAGState) -> RAGState:
    zero_shot_prompt = ChatPromptTemplate.from_messages([
        ("system", system_rewrite_prompt),
        ("human", "{question}")
    ])
    chain = zero_shot_prompt | language_model
    response = chain.invoke({"question": state["query"]})

    # One subquery per non-empty line, with the numbering removed
    # (only when a marker is actually present).
    rewrited_queries = []
    for line in response.content.splitlines():
        content = LIST_MARKER.sub("", line).strip()
        if content:
            rewrited_queries.append(content)

    return {
        "query": state["query"],
        "rewrited_queries": rewrited_queries,
        "documents": [],
        "sources": [],
        "answer": ""
    }
```

## Node: Retrieval and deduplication

The retriever runs once per rewritten query. The results are merged into a list of documents that are unique by content, and the source of each document is kept.
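The merge-and-deduplicate logic can be illustrated without FAISS. A minimal sketch in which a hypothetical dictionary `FAKE_HITS` stands in for the retriever, returning `(content, source)` pairs per subquery; the order of first appearance is preserved and each document keeps its source at the same index:

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for retriever.invoke(): subquery -> list of (content, source) hits.
FAKE_HITS: Dict[str, List[Tuple[str, str]]] = {
    "what is rest": [("REST uses HTTP verbs.", "rest.md"), ("APIs expose resources.", "api.md")],
    "rest principles": [("REST uses HTTP verbs.", "rest.md"), ("REST is stateless.", "rest.md")],
}


def retrieve_unique(subqueries: List[str]) -> Tuple[List[str], List[str]]:
    """Union of hits over all subqueries, deduplicated by exact content, in first-seen order."""
    seen = set()
    documents: List[str] = []
    sources: List[str] = []
    for subquery in subqueries:
        for content, source in FAKE_HITS.get(subquery, []):
            if content not in seen:
                seen.add(content)
                documents.append(content)
                sources.append(source)
    return documents, sources


docs, srcs = retrieve_unique(["what is rest", "rest principles"])
print(docs)  # ['REST uses HTTP verbs.', 'APIs expose resources.', 'REST is stateless.']
print(srcs)  # ['rest.md', 'api.md', 'rest.md']
```

With `k=10` and three rewritten queries, at most 30 chunks reach this step. Deduplication only removes exact repeats, so near-identical chunks survive and are left for the reranker to order.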

```python
def retrieve_documents(state: RAGState) -> RAGState:
    query = state["query"]
    rewrited_queries = state["rewrited_queries"]
    seen_contents = set()
    documents = []
    sources = []

    # Run the retriever for every rewritten query and merge the hits.
    for subquery in rewrited_queries:
        docs = retriever.invoke(subquery)
        for doc in docs:
            content = doc.page_content
            # Exact-content deduplication: several subqueries often return the same chunk.
            if content not in seen_contents:
                seen_contents.add(content)
                documents.append(content)
                # Falls back to "unknown" if the chunk has no "source_file" metadata.
                sources.append(doc.metadata.get("source_file", "unknown"))

    return {
        "query": query,
        "rewrited_queries": rewrited_queries,
        "documents": documents,
        "sources": sources,
        "answer": ""
    }
```

## Node: Reranking

Each (**original** query, document) pair is scored with the cross-encoder. The documents are sorted by score and only the top-k (here, 3) documents and their sources are kept.
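The sort-and-truncate step is independent of the scoring model. A minimal sketch in which the scorer is passed in as a function; `rerank_top_k` and the word-overlap scorer `overlap_scores` are hypothetical, and the toy scorer is only a stand-in for the cross-encoder, not a substitute for it:

```python
from typing import Callable, List, Sequence, Tuple


def rerank_top_k(query: str, documents: Sequence[str], sources: Sequence[str],
                 score_pairs: Callable[[List[Tuple[str, str]]], Sequence[float]],
                 top_k: int = 3) -> Tuple[List[str], List[str]]:
    """Score (query, doc) pairs, sort by score descending, keep top_k docs with their sources."""
    if not documents:
        return [], []
    scores = score_pairs([(query, doc) for doc in documents])
    # Zipping keeps each document next to its source while sorting.
    ranked = sorted(zip(documents, sources, scores), key=lambda t: t[2], reverse=True)
    kept = ranked[:top_k]
    return [d for d, _, _ in kept], [s for _, s, _ in kept]


def overlap_scores(pairs: List[Tuple[str, str]]) -> List[float]:
    """Toy scorer (hypothetical): number of query words that also appear in the document."""
    return [float(len(set(q.lower().split()) & set(d.lower().split()))) for q, d in pairs]


docs = ["rest is stateless", "graphql schema", "rest uses http verbs", "rest api design"]
srcs = ["a.md", "b.md", "c.md", "d.md"]
print(rerank_top_k("rest api principles", docs, srcs, overlap_scores, top_k=2))
# (['rest api design', 'rest is stateless'], ['d.md', 'a.md'])
```

`cross_encoder.predict` accepts the same list of (query, document) pairs and returns one score per pair, so it could be passed as `score_pairs` in place of the toy scorer. Python's sort is stable, so ties keep their retrieval order.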

```python
def rerank_documents(state: RAGState) -> RAGState:
    documents = state["documents"]
    sources = state["sources"]
    query = state["query"]
    top_k = 3

    if not documents:
        # Nothing was retrieved: skip the cross-encoder call.
        return {**state, "documents": [], "sources": []}

    # Score each (original query, document) pair with the cross-encoder.
    pairs = [(query, doc) for doc in documents]
    scores = cross_encoder.predict(pairs)
    # Keep each document aligned with its source while sorting by score (highest first).
    scored_docs = list(zip(documents, sources, scores))
    reranked = sorted(scored_docs, key=lambda x: x[2], reverse=True)

    reranked_docs = [doc for doc, _, _ in reranked[:top_k]]
    reranked_sources = [src for _, src, _ in reranked[:top_k]]

    return {
        "query": query,
        "rewrited_queries": state["rewrited_queries"],
        "documents": reranked_docs,
        "sources": reranked_sources,
        "answer": ""
    }
```

## Node: Answer generation

The context is built from the reranked documents, and the model is invoked with a prompt that restricts the answer to that context.
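The notebook simply joins the documents with blank lines. One possible variant, shown only as a sketch: label each chunk with a number and its source (which would let the prompt ask the model to cite them) and enforce a character budget. `build_context`, `max_chars` and the `[i] (source)` label format are hypothetical choices, not part of the notebook:

```python
from typing import List, Sequence


def build_context(documents: Sequence[str], sources: Sequence[str], max_chars: int = 4000) -> str:
    """Join labeled documents with blank lines, stopping before max_chars would be exceeded."""
    parts: List[str] = []
    used = 0
    for i, (doc, src) in enumerate(zip(documents, sources), 1):
        block = f"[{i}] ({src})\n{doc}"
        extra = len(block) + (2 if parts else 0)  # +2 for the "\n\n" separator
        if used + extra > max_chars:
            break
        parts.append(block)
        used += extra
    return "\n\n".join(parts)


print(build_context(["REST is stateless.", "REST uses HTTP verbs."], ["rest.md", "http.md"]))
# [1] (rest.md)
# REST is stateless.
#
# [2] (http.md)
# REST uses HTTP verbs.
```

Because the documents arrive sorted by reranking score, cutting at the budget drops the least relevant chunks first.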

```python
prompt = ChatPromptTemplate.from_template("""
You are an assistant specialized in answering questions about software development
technical documentation. Use only the information in the provided context to answer the question.
If the context does not contain the answer, state clearly that you do not have that information.

Question: {question}

Context: {context}

Answer:
""")

def generate_answer(state: RAGState) -> RAGState:
    # The reranked documents are joined with blank lines to form the context.
    context = "\n\n".join(state["documents"])
    chain = prompt | language_model
    response = chain.invoke({
        "context": context,
        "question": state["query"]
    })
    # Keep every other field and only fill in the answer.
    return {
        **state,
        "answer": response.content
    }
```

## LangGraph graph and compilation

Sequence: `rewrite_query` → `retrieve_documents` → `rerank_documents` → `generate_answer` → `END`.
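For intuition about what the compiled linear graph does, here is a simplified pure-Python model. It is not LangGraph's implementation; it only mimics the behavior for state keys without a reducer, where a value returned by a node replaces the previous one. This also means a node may return just the keys it changes. `run_linear` and the toy nodes are hypothetical:

```python
from typing import Callable, List

Node = Callable[[dict], dict]


def run_linear(nodes: List[Node], state: dict) -> dict:
    """Run nodes in order; each returned (possibly partial) update overwrites matching keys."""
    for node in nodes:
        state = {**state, **node(state)}
    return state


# Toy nodes (hypothetical) that return only the keys they change.
def rewrite(s: dict) -> dict:
    return {"rewrited_queries": [s["query"], s["query"] + " principles"]}


def retrieve(s: dict) -> dict:
    return {"documents": [f"doc for {q}" for q in s["rewrited_queries"]]}


def rerank(s: dict) -> dict:
    return {"documents": s["documents"][:1]}


def generate(s: dict) -> dict:
    return {"answer": f"{len(s['documents'])} doc(s) used"}


final = run_linear([rewrite, retrieve, rerank, generate], {"query": "REST"})
print(final["answer"])  # 1 doc(s) used
```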

```python
graph = StateGraph(RAGState)
graph.add_node("rewrite_query", rewrite_query)
graph.add_node("retrieve_documents", retrieve_documents)
graph.add_node("rerank_documents", rerank_documents)
graph.add_node("generate_answer", generate_answer)

# Linear topology: each node runs exactly once, in this order.
graph.set_entry_point("rewrite_query")
graph.add_edge("rewrite_query", "retrieve_documents")
graph.add_edge("retrieve_documents", "rerank_documents")
graph.add_edge("rerank_documents", "generate_answer")
graph.add_edge("generate_answer", END)

rag_app = graph.compile()
print("RAG graph compiled.")
```

## The ask_rag interface and demo

`ask_rag` initializes the state, runs the graph and returns the final state.

```python
def ask_rag(question: str) -> RAGState:
    """Run the full pipeline for one question and return the final state."""
    initial_state: RAGState = {
        "query": question,
        "rewrited_queries": [],
        "documents": [],
        "sources": [],
        "answer": ""
    }
    return rag_app.invoke(initial_state)
```

```python
result = ask_rag("What are REST APIs and what are their core principles?")

separator = "-" * 80
print("\nANSWER GENERATED BY THE RAG SYSTEM")
print("=" * 80)
print(f"\nOriginal question:\n{result['query']}")
print(f"\nRewritten queries:\n{separator}")
for i, q in enumerate(result["rewrited_queries"], 1):
    print(f"{i}. {q}")
print(f"\nAnswer:\n{separator}")
print(result["answer"])
print(f"\nSources used:\n{separator}")
for i, source in enumerate(result["sources"], 1):
    print(f"{i}. {source}")
```