In [1]:
import os

# Switch from the notebook/ directory to its parent (presumably the project
# root, where the relative config paths and `src` package live).
# The guard makes the cell idempotent: re-running it no longer climbs one
# more directory level each time.
print(os.getcwd())
if os.path.basename(os.getcwd()) == "notebook":
    os.chdir("../")
print(os.getcwd())


/home/ryefoxlime/MathTutor/notebook
/home/ryefoxlime/MathTutor


In [2]:
import datetime
import os
import uuid
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
from fastapi import HTTPException
from guardrails.errors import ValidationError
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

import inngest
from guard.guardrails import InputGuard, OutputGuard
from src.Math import logger
from src.Math.components.data_ingestion import DataLoader
from src.Math.components.data_storing import QdrantStorage
from src.Math.config.configuration import ConfigurationManager
from src.Math.entity.config_entity import (
    GraphState,
)

# Load variables from a local .env file into the process environment;
# OPENROUTER_API_KEY is read from the environment in the configuration cell.
load_dotenv()


True

In [3]:
# File name for feedback records (not referenced elsewhere in this section).
FEEDBACK_FILE = "feedback.jsonl"

# Project configuration and the components built from it.
config_manager = ConfigurationManager()
data_ingestion_config = config_manager.get_data_ingestion_config()
qdrant_config = config_manager.get_data_storing_params()

data_loader = DataLoader(data_ingestion_config)
qdrant_storage = QdrantStorage(config=qdrant_config)

# Guardrails applied to the incoming question and the generated answer.
input_guard = InputGuard()
output_guard = OutputGuard()

# LLM connection settings come from the first configured model entry.
_llm_params = config_manager.config.models[0].parameters
model_name = _llm_params.model
base_url = _llm_params.base_url
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
    # Surface the misconfiguration here instead of as an opaque
    # authentication failure deep inside a graph node.
    logger.warning(
        "OPENROUTER_API_KEY is not set; LLM calls may fail to authenticate."
    )

# Initialize Inngest client
inngest_client = inngest.Inngest(
    app_id="rag_app",
    logger=logger,
    is_production=False,
    serializer=inngest.PydanticSerializer(),
)

# History token count above which prepare_context summarizes the
# conversation instead of passing it through verbatim.
SUMMARY_THRESHOLD = 102_400

# Minimum similarity score for a retrieved document to count as relevant.
RELEVANCE_THRESHOLD = 0.70

[2025-11-09 16:58:51,891: INFO: common]: yaml file: config/config.yaml loaded successfully]
[2025-11-09 16:58:51,893: INFO: common]: yaml file: params.yaml loaded successfully]


In [4]:
def prepare_context(state: GraphState) -> dict[str, Any]:
    """Build the conversation context that is passed on to generation.

    Args:
        state: Graph state; reads ``history``, ``question`` and
            ``history_tokens``. Each history entry is indexed with the
            ``"sender"`` and ``"text"`` keys.

    Returns:
        A partial state update with ``summary`` and ``history_tokens``:

        * empty history -> ``{"summary": "", "history_tokens": 0}``;
        * token count at or below ``SUMMARY_THRESHOLD`` -> the full
          transcript and the unchanged token count;
        * otherwise -> an LLM-written summary and the token count reported
          by that call, or an empty summary with 0 tokens if the call fails.
    """
    logger.debug("Running prepare_context node")
    history = state.history
    question = state.question
    current_history_tokens = state.history_tokens

    if not history:
        return {"summary": "", "history_tokens": 0}

    logger.info(f"Context threshold set to {SUMMARY_THRESHOLD} tokens.")

    # Render the transcript once; both branches below need the same text.
    prompt_history = []
    for msg in history:
        role = "User" if msg["sender"] == "user" else "Assistant"
        prompt_history.append(f"{role}: {msg['text']}")
    history_str = "\n".join(prompt_history)

    if current_history_tokens <= SUMMARY_THRESHOLD:
        logger.info("History is below threshold. Using full history.")
        return {
            "summary": history_str,
            "history_tokens": current_history_tokens,
        }

    logger.info("History is over threshold. Summarizing...")
    summary_prompt = (
        "You are a helpful summarization assistant. "
        "Condense the following conversation into a short paragraph. "
        "Maintain Key Topics from the conversation especially any "
        "formulas or transformations. "
        "Focus on the main topics and any information relevant to the "
        "user's *new* question.\n\n"
        "--- CONVERSATION ---\n"
        f"{history_str}\n\n"
        "--- USER'S NEW QUESTION ---\n"
        f"{question}\n\n"
        "--- CONCISE SUMMARY ---\n"
    )
    try:
        summarizer_llm = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
        )
        # Synchronous call: this node is a plain (non-async) function.
        response = summarizer_llm.invoke(summary_prompt)

        new_summary = getattr(response, "content", str(response))
        # NOTE(review): "total_tokens" presumably counts the summarization
        # prompt as well as the summary itself - confirm this is the number
        # that should become the new history size.
        new_token_count = response.response_metadata["token_usage"][
            "total_tokens"
        ]

        logger.info(f"Summarization complete. New token count: {new_token_count}")
        return {"summary": new_summary, "history_tokens": new_token_count}

    except Exception as e:
        # Best effort: a failed summary degrades to "no history" rather
        # than aborting the whole graph run.
        logger.error(f"Summarization failed: {e}")
        return {"summary": "", "history_tokens": 0}

In [5]:
def validate_question(state: GraphState) -> dict[str, Any]:
    """Run the input guardrail on the user's question.

    Args:
        state: Graph state; only ``question`` is read.

    Returns:
        ``{"is_valid": True}`` when ``input_guard.validate`` accepts the
        question. When it raises ``ValidationError``, returns
        ``is_valid=False`` together with a fixed refusal message in
        ``generation``.
    """
    logger.info("Running validate_question node")
    question = state.question
    try:
        input_guard.validate(text_to_validate=question)
    except ValidationError as e:
        logger.error("Input validation failed: %s", e)
        return {
            "is_valid": False,
            "generation": "I can only answer math-related questions.",
        }
    logger.info("Input validation passed.")
    return {"is_valid": True}


In [6]:
def retrieve(state: GraphState) -> dict[str, Any]:
    """Retrieve knowledge-base documents relevant to the question.

    Embeds the question, searches the vector store for the top 5 matches
    and keeps only documents whose score is at least
    ``RELEVANCE_THRESHOLD``.

    Args:
        state: Graph state; only ``question`` is read.

    Returns:
        A partial state update with ``documents`` (the retained texts) and
        ``is_kb_relevant`` (whether any usable document remains).
    """
    question = state.question
    logger.info("Retrieving documents for question: %s", question[:50])

    query_vec = data_loader.embed_query(question)

    # The threshold is also passed to the store so it can filter at the
    # database level when supported.
    found = qdrant_storage.search(query_vec, 5, score_threshold=RELEVANCE_THRESHOLD)

    if not found or not found.get("contexts"):
        logger.warning("No search results returned from Qdrant")
        return {"documents": [], "is_kb_relevant": False}

    documents = found.get("contexts", [])
    scores = found.get("scores", [])

    if not scores:
        # Without scores the only possible check is "is there any text".
        logger.warning("No scores returned from Qdrant. Using simple check.")
        is_kb_relevant = bool(documents and any(doc.strip() for doc in documents))
    else:
        # max() rather than scores[0]: do not depend on result ordering.
        top_score = max(scores)
        logger.info(f"Top document score: {top_score}")

        # Application-level filter kept on purpose, in case the
        # score_threshold argument was silently ignored by the store.
        documents = [
            doc
            for doc, score in zip(documents, scores)
            if score >= RELEVANCE_THRESHOLD
        ]
        is_kb_relevant = bool(documents)
        if is_kb_relevant:
            logger.info(f"Found {len(documents)} relevant docs above threshold.")
        else:
            logger.info("Top score below threshold. Ignoring KB.")

    logger.info("KB relevance: %s", is_kb_relevant)

    return {
        "documents": documents,
        "is_kb_relevant": is_kb_relevant,
    }

In [7]:
def should_web_search(state: GraphState) -> str:
    """Route to context preparation when the KB is relevant, else to web search.

    Returns the route name ``"prepare_context"`` or ``"web_search"``.
    """
    kb_hit = state.is_kb_relevant
    logger.debug("should_web_search check: is_kb_relevant=%s", kb_hit)
    return "prepare_context" if kb_hit else "web_search"

In [9]:
def generate(state: GraphState) -> dict[str, Any]:
    """Generate the answer with the LLM and validate it with the output guard.

    Args:
        state: Graph state; reads ``question``, ``documents``, ``summary``
            and ``history_tokens``.

    Returns:
        A partial state update with ``generation`` and ``history_tokens``.
        On success ``history_tokens`` is increased by the call's reported
        ``total_tokens``. If the output guard rejects the answer, the
        guard's message becomes the generation. On any other failure a
        generic error message is returned. In both failure cases the token
        count is left unchanged.
    """
    logger.debug("Running generate node")
    question = state.question
    documents = state.documents
    summary = state.summary
    current_history_tokens = state.history_tokens

    # Drop empty / whitespace-only documents before building the context.
    valid_docs = [doc for doc in documents if doc and doc.strip()]
    context_str = "\n\n".join(valid_docs) if valid_docs else ""

    if context_str:
        prompt = (
            "You are a helpful math assistant. Use the following context to "
            "answer the question accurately.\n"
            "If the context doesn't contain enough information, say "
            '"I don\'t have enough information to answer this question."\n\n'
            f"Context:\n{context_str}\n\n"
            f"Conversation History:\n{summary}\n\n"
            f"Question:\n{question}\n\n"
            "Answer:\n"
        )
    else:
        prompt = (
            "You are a helpful math assistant. Answer the following question.\n"
            "If you don't know the answer, say "
            '"I don\'t have enough information to answer this question."\n\n'
            f"Conversation History:\n{summary}\n\n"
            f"Question:\n{question}\n\n"
            "Answer:\n"
        )

    try:
        llm = ChatOpenAI(model=model_name, base_url=base_url, api_key=api_key)
        # This node is synchronous, so use the blocking call. The previous
        # un-awaited `ainvoke` returned a coroutine, which made the
        # metadata lookup below fail on every request.
        response = llm.invoke(prompt)
        generation_cost = response.response_metadata["token_usage"]["total_tokens"]
        new_total_tokens = current_history_tokens + generation_cost
        content = getattr(response, "content", str(response))

        output_guard.validate(text_to_validate=content)
        logger.info("LLM generation complete and validated (len=%d)", len(content))
        return {"generation": content, "history_tokens": new_total_tokens}

    except ValidationError as ve:
        logger.warning("Output Guardrail Failed: %s", ve)
        return {
            "generation": str(ve),
            "history_tokens": current_history_tokens,
        }
    except Exception as exc:
        logger.exception("LLM generation failed: %s", exc)
        return {
            "generation": (
                "An error occurred while generating the answer. "
                "Please try again."
            ),
            "history_tokens": current_history_tokens,
        }

In [10]:
# Assemble the graph:
#   validate_question -> retrieve -> prepare_context -> generate -> END
# with early exits when validation fails or the KB has nothing relevant.
workflow = StateGraph(GraphState)

# Only state-updating functions are nodes. `should_web_search` returns a
# route name (str), not a state update, so it is used purely as the router
# on the "retrieve" conditional edge below.
workflow.add_node("prepare_context", prepare_context)
workflow.add_node("validate_question", validate_question)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)

workflow.set_entry_point("validate_question")


def after_validation(state: GraphState) -> str:
    """Continue to retrieval for valid questions, otherwise stop."""
    return "retrieve" if state.is_valid else END


workflow.add_conditional_edges(
    "validate_question",
    after_validation,
    {END: END, "retrieve": "retrieve"},
)
# The mapping keys must match what `should_web_search` actually returns:
# "prepare_context" or "web_search".
# TODO: no web-search node is defined in this notebook, so the
# "web_search" route ends the run; point it at that node once it exists.
workflow.add_conditional_edges(
    "retrieve",
    should_web_search,
    {"web_search": END, "prepare_context": "prepare_context"},
)
workflow.add_edge("prepare_context", "generate")
workflow.add_edge("generate", END)

app_graph = workflow.compile()

In [17]:
def query(question: str) -> dict[str, Any]:
    """Query the RAG pipeline directly without Inngest.

    Args:
        question: The user's question; must contain non-whitespace text.

    Returns:
        A dict with ``answer`` (the graph's ``generation`` value, or an
        error description if the graph run raised) and ``sources`` (the
        documents left in the final state, or an empty list).

    Raises:
        ValueError: If ``question`` is empty or only whitespace.
    """
    if not question or not question.strip():
        raise ValueError("Question cannot be empty")

    inputs = {
        "question": question,
        "is_kb_relevant": False,
    }

    try:
        # Blocking call: this function and all graph nodes are synchronous.
        # The previous un-awaited `ainvoke` returned a coroutine, so the
        # `.get` below failed outside this try block.
        final_state = app_graph.invoke(inputs)
    except Exception as e:
        logger.error(f"Graph invoke failed: {e}", exc_info=True)
        return {
            "answer": f"An error occurred during pipeline execution: {e}",
            "sources": [],
        }

    # Same keys as the error payload so callers can rely on one shape.
    return {
        "answer": final_state.get("generation"),
        "sources": final_state.get("documents") or [],
    }