In [None]:
# 1) Setup & Installs
!pip -q install langchain langchain-community langchain-text-splitters \
               faiss-cpu sentence-transformers gpt4all gradio requests numexpr


In [None]:
# 2) Upload your files
from google.colab import files
# NOTE(review): files.upload() presumably returns a mapping keyed by uploaded file name
# (only its .keys() are used here) — confirm. This cell only works inside Google Colab.
uploaded = files.upload()  # Select the six files from your computer
print("Uploaded:", list(uploaded.keys()))


In [None]:
# 3) Ingestion: load → chunk → embed → index
import os
from typing import List, Dict, Tuple
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# Ensure all expected files exist in the current working directory
EXPECTED_FILES = [
    "amazon_sales_stats.txt",
    "amazon_echo_dot_reviews.txt",
    "amazon_echo_dot_manual.txt",
    "amazon_policy_guidelines.txt",
    "amazon_echo_dot_specs.txt",
    "amazon_faq.txt",
]
# Names are relative, so they are resolved against the current working directory.
# All absent files are collected first so the error lists every one of them at once.
missing = [f for f in EXPECTED_FILES if not os.path.exists(f)]
if missing:
    raise FileNotFoundError(f"Missing files: {missing}. Please upload them in the previous cell.")

# Load documents with source metadata
docs = []
for path in EXPECTED_FILES:
    loader = TextLoader(path, encoding="utf-8")
    for d in loader.load():
        # Keep filename as "source" for UI
        d.metadata["source"] = os.path.basename(path)
        docs.append(d)

# Chunk
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
splits = splitter.split_documents(docs)

# Embeddings (local, no API key needed)
# NOTE: `embeddings` and `vectorstore` are module globals; the later hybrid-retrieval cell
# assigns both names again, so whichever cell ran last decides what `vectorstore` holds.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Vector store
vectorstore = FAISS.from_documents(splits, embeddings)

# Convenience retriever
def retrieve(query: str, k: int = 3):
    """Look up the `k` chunks closest to `query` in the global FAISS `vectorstore`.

    Args:
        query: Free-text search string.
        k: Number of hits to request.

    Returns:
        Whatever `vectorstore.similarity_search_with_score` returns for the query:
        (doc, score) pairs where a lower score means a closer match.
    """
    return vectorstore.similarity_search_with_score(query, k=k)

# Ingestion summary: number of source documents loaded vs. number of chunks produced by the splitter.
print(f"Loaded {len(docs)} docs | {len(splits)} chunks indexed.")


In [None]:
# Step 4: LLM using Hugging Face Flan-T5-Base
!pip -q install transformers accelerate

from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Load Flan-T5-base
# Tokenizer and model are both resolved from the same checkpoint name.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Build pipeline
hf_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=False  # deterministic
)

# Wrap for LangChain
# `llm` is the module global that rag_answer, calculator_agent and _llm_call look up
# (they test `'llm' in globals()`), so this cell must run before any of them is called.
llm = HuggingFacePipeline(pipeline=hf_pipeline)

print("✅ Flan-T5-base ready as LLM")


In [None]:
# ============================================
# Generic, Domain-Agnostic RAG (no special rules)
# - Better chunking
# - Hybrid retrieval: Embeddings (FAISS, MMR) + BM25
# - Optional Cross-Encoder re-ranking (fallback-safe)
# - Strict grounded answer chain with GK fallback
# - Utility: retrieve() returns (Document, pseudo_distance) for your UI
# ============================================

# 0) Installs (Colab safe). Re-run if kernel restarts.
!pip -q install langchain langchain-community sentence-transformers rank-bm25 faiss-cpu

# 1) Imports
import os, glob, math, numpy as np
from typing import List, Tuple
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

# Optional cross-encoder re-ranker (falls back automatically if not available)
# Loading the re-ranker is best-effort: when the package is missing or the model cannot be
# loaded, `_CROSS_ENCODER` stays None and rerank_with_cross_encoder() simply truncates.
try:
    from sentence_transformers import CrossEncoder
    _CROSS_ENCODER = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
except Exception as e:
    # Say why re-ranking is off instead of degrading silently; retrieval quality changes
    # noticeably without it and the cause is otherwise invisible.
    print(f"Cross-encoder unavailable, re-ranking disabled: {type(e).__name__}: {e}")
    _CROSS_ENCODER = None

# 2) Load your uploaded files (adjust paths if needed)
def _collect_files():
    """Collect candidate corpus files from the usual Colab locations.

    Looks in /mnt/data first and /content second (directories that do not exist are
    ignored); inside each directory *.txt matches come before *.md matches.

    Returns:
        List of paths with duplicates removed, keeping the first occurrence of each.
    """
    search_roots = ("/mnt/data", "/content")
    patterns = ("*.txt", "*.md")
    found = [
        match
        for root in search_roots
        if os.path.isdir(root)
        for pattern in patterns
        for match in glob.glob(os.path.join(root, pattern))
    ]
    # dict keys keep insertion order, so this drops repeats without reordering
    return list(dict.fromkeys(found))

# Discover the corpus; stop the cell right away when nothing was found.
FILE_PATHS = _collect_files()
assert FILE_PATHS, "No .txt/.md files found in /mnt/data or /content. Upload your files first."

# Load every discovered file. A file that fails to load is reported and skipped, so
# raw_docs may end up covering fewer files than FILE_PATHS lists.
raw_docs: List[Document] = []
for p in FILE_PATHS:
    try:
        d = TextLoader(p, encoding="utf-8").load()
        # Attach a short source name for cleaner UI
        for doc in d:
            doc.metadata["source"] = os.path.basename(p)
        raw_docs.extend(d)
    except Exception as e:
        print(f"Skipping {p}: {e}")

# 3) Chunking (generic, keeps sections and bullets together)
# Splitter tries the separators in the order given: blank line, newline, space, then
# the empty string as the last resort.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,       # 600–1000 works well
    chunk_overlap=120,    # some overlap to keep headings + bullets together
    separators=["\n\n", "\n", " ", ""]
)
split_docs: List[Document] = text_splitter.split_documents(raw_docs)

# Keep a global copy if your other code needs it
GLOBAL_DOCS = split_docs

# 4) Embeddings + FAISS (MMR retriever)
# NOTE: this assigns `embeddings` and `vectorstore` again (the ingestion cell above binds
# the same names); from here on they refer to the index built from `split_docs`.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(split_docs, embeddings)

mmr_retriever = vectorstore.as_retriever(
    search_type="mmr",            # Maximal Marginal Relevance to reduce redundancy
    search_kwargs={
        "k": 12,                  # retrieve more, we'll re-rank down
        "fetch_k": 50,            # candidates pool size
        "lambda_mult": 0.4        # trade-off diversity vs similarity (0..1)
    }
)

# 5) Sparse BM25 retriever
# Built over the same chunks as the FAISS index, so both retrievers return comparable units.
bm25_retriever = BM25Retriever.from_documents(split_docs)
bm25_retriever.k = 20             # cast a wider net; will re-rank later

# 6) Ensemble (dense + sparse) via Reciprocal Rank Fusion under the hood
# `ensemble` is the object the retrieve() helper below queries.
ensemble = EnsembleRetriever(
    retrievers=[mmr_retriever, bm25_retriever],
    weights=[0.5, 0.5]            # balanced; adjust if you prefer denser or sparser bias
)

# 7) Optional cross-encoder re-ranking (query, docs) -> top_k
def rerank_with_cross_encoder(query: str, docs: List[Document], top_k: int = 6) -> List[Document]:
    """Order candidate chunks by cross-encoder relevance and keep the best `top_k`.

    Args:
        query: The user question each chunk is scored against.
        docs: Candidate documents in their incoming order.
        top_k: Maximum number of documents to return.

    Returns:
        The highest-scoring documents, best first. When no cross-encoder was loaded or
        `docs` is empty, the first `top_k` documents are returned in their incoming order.
    """
    if _CROSS_ENCODER is not None and docs:
        relevance = _CROSS_ENCODER.predict([(query, doc.page_content) for doc in docs])
        # Sort positions instead of (doc, score) pairs; the sort is stable, so ties keep
        # their incoming order.
        order = sorted(range(len(docs)), key=lambda idx: relevance[idx], reverse=True)
        return [docs[idx] for idx in order[:top_k]]
    return docs[:top_k]

# 8) A retrieve() helper compatible with your UI (returns (Document, pseudo_distance))
def retrieve(query: str, k: int = 6) -> List[Tuple[Document, float]]:
    """Hybrid retrieval: fused dense+sparse candidates, optionally re-ranked, cut to `k`.

    Args:
        query: The user question.
        k: Maximum number of documents to return.

    Returns:
        A list of (Document, pseudo_distance) pairs, best first. The pseudo-distance is
        only the normalized rank: 0.0 for the top hit, 1.0 for the last one (0.0 when
        there is a single hit). It is not a similarity score. An empty list is returned
        when nothing was retrieved.
    """
    # `invoke` is the supported retriever entry point; `get_relevant_documents` is the
    # deprecated spelling and is kept only as a fallback for old LangChain versions.
    fetch = getattr(ensemble, "invoke", None) or ensemble.get_relevant_documents
    candidates = fetch(query)

    final_docs = rerank_with_cross_encoder(query, candidates, top_k=k)
    if not final_docs:
        return []

    # Dividing the rank by (n - 1) maps positions onto [0, 1]; max(1, ...) covers n == 1.
    span = max(1, len(final_docs) - 1)
    return [(doc, rank / span) for rank, doc in enumerate(final_docs)]

# 9) Utility for your Gradio UI (snippets display)
def build_context_snippets(retrieved: List[Tuple[Document, float]]) -> List[str]:
    """Format retrieved chunks as one-line previews for the UI.

    Args:
        retrieved: (Document, score) pairs as produced by retrieve().

    Returns:
        One string per pair: the source name in brackets, the score with four decimals,
        and the chunk text on a single line, cut to 280 characters plus "..." when longer.
    """
    def _render(doc, dist):
        source = doc.metadata.get("source", "unknown.txt")
        text = doc.page_content.strip().replace("\n", " ")
        if len(text) > 280:
            text = text[:280].rstrip() + "..."
        return f"[{source}] (score={dist:.4f}): {text}"

    return [_render(doc, dist) for doc, dist in retrieved]

# 10) Strict grounded answer chain with generic fallback (no hand-crafted logic)
# NOTE: This assumes you already initialized `llm` elsewhere (GPT4All or Flan-T5).
# Grounded prompt. Placeholders: {question}, {context}. The refusal sentence in it is the
# phrase rag_answer() searches for (case-insensitively) to detect "no answer in context".
STRICT_RAG_PROMPT = """You are a grounded assistant. Use ONLY the context to answer.
- If the answer is in the context, answer concisely.
- If not present, say: "Not found in the uploaded files."
Question: {question}

Context:
{context}

Answer:"""

# General-knowledge prompt used when the grounded attempt gives nothing. Placeholder: {question}.
GK_FALLBACK_PROMPT = """Answer the question using general knowledge. Be factual and concise.

Question: {question}

Answer:"""

def rag_answer(question: str, k: int = 6):
    """Answer from retrieved context, falling back to general knowledge.

    Args:
        question: The user question.
        k: Number of chunks to retrieve.

    Returns:
        (branch label, snippet list, answer text). Without a configured global `llm`
        only a status message is returned. Otherwise the grounded answer is used unless
        it is empty or contains "not found in the uploaded files", in which case the
        general-knowledge prompt is asked instead.
    """
    hits = retrieve(question, k=k)
    snippets = build_context_snippets(hits)
    # Joining an empty hit list already yields "", so no separate empty case is needed.
    context = "\n\n".join(doc.page_content for doc, _ in hits)

    llm_ready = 'llm' in globals() and llm is not None
    if not llm_ready:
        if context:
            return "RAG (GPT4All)", snippets, "LLM not configured. Retrieved context shown above."
        return "RAG (GPT4All)", [], "LLM not configured and no context retrieved."

    grounded = (llm.invoke(STRICT_RAG_PROMPT.format(question=question, context=context)) or "").strip()
    if grounded and "not found in the uploaded files" not in grounded.lower():
        return "RAG (GPT4All)", snippets, grounded

    # Grounded attempt was empty or a refusal: ask again without any context.
    fallback = (llm.invoke(GK_FALLBACK_PROMPT.format(question=question)) or "").strip()
    return "RAG (GPT4All) + General Knowledge Fallback", snippets, fallback or "No answer produced."

# ============================
# Sanity tests you can try (outside of Gradio)
# ============================
# print(rag_answer("Explain the negative feedback of Amazon Echo Dot")[2])
# print(rag_answer("What are the colors available for Echo Dot?")[2])
# print(rag_answer("What is Amazon’s return policy?")[2])


In [None]:
# 6) Dictionary Agent — Free Dictionary API
import requests
import urllib.parse

FREE_DICT_BASE = "https://api.dictionaryapi.dev/api/v2/entries/en/"

def dictionary_lookup_single(term: str) -> str:
    """Look up one word or phrase via the Free Dictionary API (dictionaryapi.dev).

    Args:
        term: Word or phrase to define; surrounding whitespace is ignored for the lookup.

    Returns:
        "<term>: <definition> | <definition>" with at most two definitions taken from the
        first meaning that has any; "Empty term." for blank input; a "No entry found"
        message for a non-200 response or a payload without definitions; an
        "Error fetching definition" message when the request or parsing raises.
    """
    clean = term.strip()
    if not clean:
        return "Empty term."

    # safe="" also percent-encodes "/", so a term such as "and/or" stays one path segment
    # instead of being read by the server as two.
    url = FREE_DICT_BASE + urllib.parse.quote(clean, safe="")
    try:
        r = requests.get(url, timeout=8)
        if r.status_code != 200:
            return f"No entry found for '{term}'."

        data = r.json()
        # Expected shape: list[entry]; only the first entry is consulted.
        defs = []
        if isinstance(data, list) and data:
            entry = data[0]
            for m in entry.get("meanings", []):
                for d in m.get("definitions", []):
                    text = (d.get("definition") or "").strip()
                    if text:
                        defs.append(text)
                if defs:
                    break  # keep it concise: stop after first meaning that has defs
        if not defs:
            return f"No entry found for '{term}'."

        # Keep it concise (up to 2 short definitions)
        preview = " | ".join(defs[:2])
        return f"{term}: {preview}"

    except Exception as e:
        # Deliberately broad: a lookup failure must turn into a message, not a crash.
        return f"Error fetching definition for '{term}': {e}"

def dictionary_agent(query: str):
    """Define the term(s) named in a dictionary-style query.

    Recognized forms (matched on the trimmed, lower-cased query):
      - 'define echo dot'             -> one phrase, 'echo dot'
      - 'define echo, dot, policy'    -> one lookup per comma-separated term
      - '... meaning of echo dot'     -> everything after the first 'meaning of '
      - 'define: echo dot'            -> one phrase
    Any other query is looked up as a whole.

    Returns:
        (branch label, snippet list, answer text), with one result line per term.
    """
    q = query.strip().lower()

    if q.startswith("define "):
        remainder = q[len("define "):]
    elif "meaning of " in q:
        remainder = q.partition("meaning of ")[2]
    elif q.startswith("define:"):
        remainder = q[len("define:"):]
    else:
        remainder = q

    # Splitting on commas covers both cases: with no comma the whole remainder comes
    # back as a single piece; blank pieces are dropped.
    terms = [piece.strip() for piece in remainder.split(",") if piece.strip()]
    if not terms:
        return "Dictionary Agent", ["No document context required"], "Please provide a word or phrase to define."

    lines = []
    for term in terms:
        lines.append(dictionary_lookup_single(term))
    return (
        "Dictionary Agent",
        ["Definitions fetched via Free Dictionary API (dictionaryapi.dev)"],
        "\n".join(lines),
    )


In [None]:
# Calculator Agent

import re
import numexpr as ne

# Lower-case month names in calendar order.
# NOTE: the router cell further down binds the global name MONTHS again, to a set, so code
# that iterates MONTHS must not rely on this ordering once that cell has run.
MONTHS = ["january","february","march","april","may","june",
          "july","august","september","october","november","december"]
# Words removed from a query before the remainder is treated as a product name.
# Queries are tokenized with [a-zA-Z]+, so the two-word entry "up to" can never match a token.
STOPWORDS = {
    "calculate","total","combined","units","unit","sold","sales","sale","of","the","a","an","for","to","in","on","by","with",
    "percentage","percent","increase","decrease","growth","from","between","and","till","until","upto","up to","average","mean",
    "highest","lowest","max","min","which","month","has","have","does","is"
}

# Accept digits, decimal point, arithmetic ops, parentheses, and spaces
# We'll normalize × x ÷ – — to * / - respectively before matching.
# Anchored at both ends, so the whole expression must consist of these characters.
ARITH_ALLOWED = re.compile(r"^[0-9\.\+\-\*/\(\)\s]+$")

# Find candidate arithmetic spans inside a longer query (choose the longest)
# A span is a number followed by at least one "operator number" pair; each number may carry
# a sign and a single surrounding parenthesis. Only non-capturing groups are used, so
# findall() returns the full matched text.
ARITH_SPAN = re.compile(r"(?:\(?\s*[-+]?\d+(?:\.\d+)?\s*\)?\s*(?:[\+\-\*/]\s*\(?\s*[-+]?\d+(?:\.\d+)?\s*\)?\s*)+)")


def _normalize_expr(expr: str) -> str:
    """Rewrite human-style arithmetic notation into plain ASCII operators.

    - "×", "·" become "*"; "÷" becomes "/"; en/em dashes become "-".
    - A letter x/X becomes "*" only when it is not part of a word ("2x3", "2 x 3"),
      so words such as "max" or "xbox" are left intact.
    - A comma used as a thousands separator (a digit before it, exactly three digits
      after it) is deleted, so "1,000" stays one number; any other comma becomes a space.
    - Runs of whitespace collapse to one space and the result is trimmed.

    Args:
        expr: A query or an expression fragment.

    Returns:
        The normalized text.
    """
    for symbol, op in (("×", "*"), ("·", "*"), ("÷", "/"), ("–", "-"), ("—", "-")):
        expr = expr.replace(symbol, op)
    # Only a stand-alone x means "times"; replacing every x would corrupt ordinary words.
    expr = re.sub(r"(?<![A-Za-z])[xX](?![A-Za-z])", "*", expr)
    # Delete thousands separators instead of turning them into spaces, which would split
    # "1,000" into the two numbers "1" and "000".
    expr = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", expr)
    expr = expr.replace(",", " ")
    return re.sub(r"\s+", " ", expr).strip()


def _is_safe_expr(expr: str) -> bool:
    """Tell whether `expr` matches ARITH_ALLOWED, i.e. consists only of whitelisted arithmetic characters."""
    return ARITH_ALLOWED.match(expr) is not None


def _extract_inline_expr(query: str):
    """Find and evaluate the longest arithmetic expression embedded in `query`.

    The query is normalized, every ARITH_SPAN match is collected, and the longest one
    (after trimming) is evaluated with numexpr.

    Returns:
        (expression, value) on success. None when the query holds no arithmetic span,
        when the chosen span fails the character whitelist, or when evaluation raises.
    """
    candidates = ARITH_SPAN.findall(_normalize_expr(query))
    if not candidates:
        return None

    # max() keeps the first of several equally long spans
    expr = _normalize_expr(max(map(str.strip, candidates), key=len))
    if _is_safe_expr(expr):
        try:
            # numexpr applies the usual operator precedence
            return expr, ne.evaluate(expr).item()
        except Exception:
            pass
    return None


def _product_phrases_and_tokens(query: str):
    """Pull candidate product names out of a query.

    The lower-cased query is cut at commas and at the word "and"; within each piece the
    alphabetic words that are not in STOPWORDS are joined with spaces to form a phrase.

    Returns:
        (phrases, tokens): the non-empty phrases in query order, and the flat list of
        all non-stopword words of the whole query.
    """
    lowered = query.lower()

    def _content_words(text):
        return [word for word in re.findall(r"[a-zA-Z]+", text) if word not in STOPWORDS]

    phrases = []
    for segment in re.split(r"\s*,\s*|\s+and\s+", lowered):
        kept = _content_words(segment)
        if kept:
            phrases.append(" ".join(kept))
    return phrases, _content_words(lowered)


def _split_product_blocks(text: str):
    """Cut `text` into pieces that each begin at a line starting with "Product:".

    The marker is matched case-insensitively at the start of a line. Text before the
    first marker forms its own piece. Pieces are trimmed and empty ones are dropped.
    """
    boundary = re.compile(r"(?=^Product:\s*)", re.IGNORECASE | re.MULTILINE)
    blocks = []
    for piece in boundary.split(text):
        piece = piece.strip()
        if piece:
            blocks.append(piece)
    return blocks


def _block_matches(block: str, phrase: str, flat_tokens: list):
    """Decide whether a product block is the one a query refers to.

    The block is lower-cased before comparing; `phrase` and the tokens are used as given.

    Returns:
        True when a non-empty `phrase` occurs in the block. Otherwise True only when
        `flat_tokens` is non-empty and every token occurs in the block.
    """
    haystack = block.lower()
    if phrase and phrase in haystack:
        return True
    if not flat_tokens:
        return False
    return all(token in haystack for token in flat_tokens)


def _month_values_from_block(block: str):
    """Read per-month unit counts from a product block.

    Each line is searched for "<Name>: <digits> units sold"; the line counts only when
    the lower-cased name is in MONTHS, e.g. "January: 150 units sold" -> {"january": 150}.
    A month that appears on several lines keeps the value of its last line.
    """
    line_pattern = re.compile(r"([A-Za-z]+):\s*(\d+)\s+units\s+sold")
    values = {}
    for hit in filter(None, map(line_pattern.search, block.splitlines())):
        name = hit.group(1).lower()
        if name in MONTHS:
            values[name] = int(hit.group(2))
    return values


# Month names are spelled out here rather than read from the global MONTHS: the router
# cell binds MONTHS again as a set, and iterating a set gives no dependable order.
_MONTH_MENTION_RE = re.compile(
    r"\b(january|february|march|april|may|june|july|august|september|october|november|december)\b"
)


def _pick_two_months_from_query(q_lower: str):
    """Return the first and the last distinct month named in a lower-cased query.

    Months are taken in the order they are mentioned, so "from march to january" gives
    ("march", "january"). Names are matched as whole words, so "may" inside "maybe" or
    "mayor" does not count.

    Args:
        q_lower: The query, already lower-cased.

    Returns:
        (first_month, last_month), or (None, None) when fewer than two different months
        are mentioned.
    """
    mentioned = []
    for hit in _MONTH_MENTION_RE.finditer(q_lower):
        name = hit.group(1)
        if name not in mentioned:
            mentioned.append(name)
    if len(mentioned) >= 2:
        return mentioned[0], mentioned[-1]
    return None, None


def calculator_agent(query: str, k: int = 5, distance_threshold: float = 1.5):
    """Answer a numeric question from inline arithmetic or from retrieved sales text.

    The steps run in order and the first one that produces a result returns:
      1. Arithmetic written in the query itself is evaluated directly (no retrieval).
      2. Otherwise `retrieve(query, k)` is called; hits whose score is at most
         `distance_threshold` are joined, split into "Product:" blocks, and matched
         against the product phrases/tokens of the query.
      3. Keyword rules pick the computation: extrema, percentage change, totals,
         average, difference between two months, ratio between two products, and
         finally a plain sum of the first matched block. Keywords are tested with
         `in` on the lower-cased query, i.e. as substrings.
      4. If nothing above applied and a global `llm` exists, the LLM is asked for an
         arithmetic expression, which is evaluated with numexpr.

    Args:
        query: The user question.
        k: Number of chunks to retrieve.
        distance_threshold: Largest retrieval score that still counts as relevant.

    Returns:
        (branch label, list of snippet strings, answer text). Exceptions are caught and
        reported in the answer text instead of being raised.
    """
    try:
        # ---------- NEW: Direct arithmetic (BODMAS) path FIRST ----------
        inline = _extract_inline_expr(query)
        if inline:
            expr, val = inline
            # No need for doc context; we evaluated a self-contained expression
            return "Calculator Agent", ["Direct expression (BODMAS) — no document context required"], f"Expression: {expr}\nResult: {val}"

        # ---------- Existing doc-aware logic below ----------
        retrieved = retrieve(query, k=k)
        # NOTE(review): the hybrid retrieve() defined in this notebook returns rank-based
        # scores between 0 and 1, so the default threshold of 1.5 keeps every hit — confirm intent.
        relevant = [r for r in retrieved if r[1] <= distance_threshold]
        # Snippets cover all retrieved hits as soon as at least one hit is relevant.
        snippets = build_context_snippets(retrieved if relevant else [])

        # Build context blob and split into product blocks
        context = "\n\n".join([r[0].page_content for r in relevant]) if relevant else ""
        blocks = _split_product_blocks(context) if context else []

        # Identify product blocks from query (multi-word supported)
        phrases, flat_tokens = _product_phrases_and_tokens(query)
        matched_blocks = []
        for ph in phrases:
            # A block is appended once per phrase it matches, so it can appear more than once.
            matched_blocks += [b for b in blocks if _block_matches(b, ph, [])]
        if not matched_blocks and flat_tokens:
            matched_blocks = [b for b in blocks if _block_matches(b, "", flat_tokens)]
        if not matched_blocks and blocks:
            # last resort: pick most relevant block (first)
            matched_blocks = [blocks[0]]

        ql = query.lower()

        # ---------------- Deterministic extrema answers (NO LLM) ----------------
        if any(w in ql for w in ["highest","max","peak"]) and matched_blocks:
            data = _month_values_from_block(matched_blocks[0])
            if data:
                best_month = max(data, key=data.get)
                return "Calculator Agent", [matched_blocks[0]], f"Result: {best_month.title()} with {data[best_month]} units"

        if any(w in ql for w in ["lowest","min"]) and matched_blocks:
            data = _month_values_from_block(matched_blocks[0])
            if data:
                worst_month = min(data, key=data.get)
                return "Calculator Agent", [matched_blocks[0]], f"Result: {worst_month.title()} with {data[worst_month]} units"

        # ---------------- Percentage increase/decrease ----------------
        if any(w in ql for w in ["percentage","percent","increase","decrease","growth"]) and matched_blocks:
            data = _month_values_from_block(matched_blocks[0])
            m1, m2 = _pick_two_months_from_query(ql)
            if m1 and m2 and m1 in data and m2 in data:
                v1, v2 = data[m1], data[m2]
                expr = f"({v2} - {v1}) / {v1} * 100"
                # A zero baseline has no percentage change; NaN is reported instead of raising.
                result = (v2 - v1) / v1 * 100 if v1 != 0 else float('nan')
                return "Calculator Agent", [matched_blocks[0]], f"Expression: {expr}\nResult: {result:.2f}%"

        # ---------------- Totals / combined totals ----------------
        if any(w in ql for w in ["total","sum","combined"]) and matched_blocks:
            expr_parts, totals, used = [], [], []
            for b in matched_blocks:
                data = _month_values_from_block(b)
                nums = list(data.values())
                if nums:
                    expr_parts.append(" + ".join(map(str, nums)))
                    totals.append(sum(nums))
                    used.append(b)
            if totals:
                if len(totals) > 1:
                    expr = " + ".join(f"({p})" for p in expr_parts)
                    res = sum(totals)
                else:
                    expr = expr_parts[0]
                    res = totals[0]
                return "Calculator Agent", used, f"Expression: {expr}\nResult: {res}"

        # ---------------- Average (mean) ----------------
        if any(w in ql for w in ["average","mean","avg"]) and matched_blocks:
            data = _month_values_from_block(matched_blocks[0])
            nums = list(data.values())
            if nums:
                expr = f"({'+'.join(map(str, nums))}) / {len(nums)}"
                res = sum(nums)/len(nums)
                return "Calculator Agent", [matched_blocks[0]], f"Expression: {expr}\nResult: {res:.2f}"

        # ---------------- Difference (absolute) if two months given ----------------
        if any(w in ql for w in ["difference","more","less"]) and matched_blocks:
            data = _month_values_from_block(matched_blocks[0])
            m1, m2 = _pick_two_months_from_query(ql)
            if m1 and m2 and m1 in data and m2 in data:
                v1, v2 = data[m1], data[m2]
                # Signed result (second month minus first), not an absolute value.
                expr = f"{v2} - {v1}"
                res  = v2 - v1
                return "Calculator Agent", [matched_blocks[0]], f"Expression: {expr}\nResult: {res}"

        # ---------------- Ratios between two products ----------------
        if any(w in ql for w in ["ratio","times","compared","compare"]) and len(matched_blocks) >= 2:
            d1 = _month_values_from_block(matched_blocks[0])
            d2 = _month_values_from_block(matched_blocks[1])
            if d1 and d2:
                s1, s2 = sum(d1.values()), sum(d2.values())
                expr = f"{s1} / {s2}"
                res  = s1 / s2 if s2 != 0 else float('inf')
                return "Calculator Agent", matched_blocks[:2], f"Expression: {expr}\nResult: {res:.2f}"

        # ---------------- Default deterministic sum of first block (as a safe fallback) ----------------
        if matched_blocks:
            data = _month_values_from_block(matched_blocks[0])
            nums = list(data.values())
            if nums:
                expr = " + ".join(map(str, nums))
                res  = sum(nums)
                return "Calculator Agent", [matched_blocks[0]], f"Expression: {expr}\nResult: {res}"

        # ---------------- Safe LLM fallback (only if llm & prompt exist) ----------------
        if relevant and 'llm' in globals() and llm:
            # MATH_PROMPT is optional: it is used only when a global of that name exists.
            local_prompt = (
                MATH_PROMPT if 'MATH_PROMPT' in globals() and MATH_PROMPT
                else "Produce ONE arithmetic expression (numbers and + - * / only) that answers the question.\nNo words.\nQuestion: {question}\nExpression:"
            )
            ctx = "\n\n".join([r[0].page_content for r in relevant])
            raw = llm.invoke(local_prompt.format(question=query, context=ctx))
            # Only the first run of arithmetic characters in the reply is kept and evaluated.
            m = re.search(r"([0-9\.\(\)\+\-\*/\s]+)", (raw or ""))
            expr = (m.group(1) if m else "").replace(",", "").strip()
            if expr:
                try:
                    result = ne.evaluate(expr).item()
                    return "Calculator Agent", snippets if snippets else ["No document context found"], f"Expression: {expr}\nResult: {result}"
                except Exception:
                    pass

        # ---------------- Final fallback (no crash) ----------------
        return "Calculator Agent", snippets if snippets else ["No document context found"], \
               "Could not compute a result from the available information."

    except Exception as e:
        # Never let exceptions bubble to Gradio as 'Error'
        return "Calculator Agent", ["Error while computing"], f"Error: {e}"


In [None]:
# ========= Drop-in patch: stop echo + ensure grounded answer =========

# 1) Prompts
# Full grounded prompt. Placeholders: {context}, {question}. Its "###" headers are the
# markers that STOP_TOKENS and _looks_like_echo() watch for.
RAG_TEMPLATE = """### Instruction:
Answer ONLY using the provided context. If the answer is not in the context, reply exactly:
Not found in the uploaded files.

### Context:
{context}

### Question:
{question}

### Answer:
"""

# Shorter prompt for the retry attempt. Placeholders: {context}, {question}.
RAG_TEMPLATE_COMPACT = (
    # Minimal template for retry (works better for small models)
    "Use ONLY the context to answer. If not in context, reply exactly: Not found in the uploaded files.\n\n"
    "Context:\n{context}\n\nQuestion: {question}\nAnswer:"
)

# Context-free prompt used as the last resort. Placeholder: {question}.
GK_FALLBACK_TEMPLATE = (
    "Answer the question using general knowledge. Be factual and concise.\n\n"
    "Question:\n{question}\n\nAnswer:"
)

# 2) Stop tokens (prevents the model from drifting back into headers)
# Handed to llm.generate(..., stop=...) by _llm_call().
STOP_TOKENS = ["\n###", "### Question:", "### Context:", "### Instruction:", "\nQuestion:", "\nContext:"]

# 3) Safer LLM call that prefers `.generate(..., stop=...)`, else falls back to `.invoke`
def _llm_call(prompt: str, stops: list[str] = None, max_retries: int = 1) -> str:
    """Send `prompt` to the global `llm` and return its trimmed text.

    `generate` is tried first because it accepts stop sequences; if it is missing,
    raises, or yields no text, `invoke` is tried up to `max_retries` times.

    Args:
        prompt: Fully formatted prompt text.
        stops: Stop sequences for `generate`; STOP_TOKENS when omitted or empty.
        max_retries: Number of `invoke` attempts.

    Returns:
        The trimmed model output, or "" when no `llm` is configured or nothing was produced.
    """
    model = globals().get('llm')
    if model is None:
        return ""
    stop_list = stops or STOP_TOKENS

    if hasattr(model, "generate"):
        try:
            result = model.generate([prompt], stop=stop_list)  # stop works for many LangChain LLMs
            has_text = result and result.generations and result.generations[0]
            text = result.generations[0][0].text if has_text else ""
            if text:
                return text.strip()
        except Exception:
            # any failure here just means: try the plain invoke path below
            pass

    # invoke takes no stop sequences in many wrappers
    reply = ""
    attempts_left = max_retries
    while attempts_left > 0:
        attempts_left -= 1
        try:
            reply = model.invoke(prompt)
            if reply:
                return reply.strip()
        except Exception:
            continue
    return (reply or "").strip()

# 4) Build concise context to reduce prompt echo
def _build_context_text(hits, max_chars: int = 2000) -> str:
    """Join retrieved chunk texts into one context string under a character budget.

    Args:
        hits: (doc, score) pairs; only `doc.page_content` is read.
        max_chars: Budget for the chunk text. The blank-line separators between chunks
            are not counted against it.

    Returns:
        The trimmed chunk texts joined by blank lines; the chunk that crosses the budget
        is cut short and later chunks are dropped. Blank chunks are skipped. "" for no hits.
    """
    if not hits:
        return ""

    pieces = []
    remaining = max_chars
    for doc, _score in hits:
        text = (doc.page_content or "").strip()
        if not text:
            continue
        clipped = text[:max(0, remaining)]
        if not clipped:
            break
        pieces.append(clipped)
        remaining -= len(clipped)
        if remaining <= 0:
            break
    return "\n\n".join(pieces)

def _looks_like_echo(txt: str) -> bool:
    """Flag model output that is unusable as an answer.

    Returns:
        True when the text is empty, is at most 10 characters long once lower-cased and
        trimmed, or contains one of the prompt's own markers (the "###" section headers
        or the instruction sentence), compared case-insensitively. False otherwise.
    """
    if not txt:
        return True
    lowered = txt.lower()
    if len(lowered.strip()) <= 10:
        return True
    prompt_markers = (
        "### instruction",
        "### context",
        "### question",
        "answer only using the provided context",
    )
    return any(marker in lowered for marker in prompt_markers)

def rag_answer(question: str, k: int = 6):
    """Answer from the uploaded files, retrying once before using general knowledge.

    Attempts, in order:
      1. No context retrieved: general-knowledge answer, labelled "RAG (FLAN-T5)".
      2. Full template with up to 2000 characters of context.
      3. Compact template with only the top hit, up to 1000 characters.
      4. General-knowledge answer, labelled as fallback.
    An attempt is accepted when it is non-empty, is not flagged by _looks_like_echo(),
    and does not contain "not found in the uploaded files".

    Args:
        question: The user question.
        k: Number of chunks to retrieve.

    Returns:
        (branch label, snippet list, answer text).
    """
    def _ask(template, **fields):
        prompt = template.format(question=question, **fields)
        return (_llm_call(prompt, STOP_TOKENS) or "").strip()

    def _usable(answer):
        return (
            bool(answer)
            and not _looks_like_echo(answer)
            and "not found in the uploaded files" not in answer.lower()
        )

    hits = retrieve(question, k=k)
    snippets = build_context_snippets(hits)
    context = _build_context_text(hits, max_chars=2000)

    if not context:
        # nothing to ground on: go straight to general knowledge
        direct = _llm_call(GK_FALLBACK_TEMPLATE.format(question=question), STOP_TOKENS)
        return "RAG (FLAN-T5)", snippets, (direct or "No answer produced.")

    first_try = _ask(RAG_TEMPLATE, context=context)
    if _usable(first_try):
        return "RAG (FLAN-T5)", snippets, first_try

    # A shorter prompt with a single chunk leaves a small model less text to parrot back.
    top_hit_context = _build_context_text(hits[:1], max_chars=1000)
    second_try = _ask(RAG_TEMPLATE_COMPACT, context=top_hit_context)
    if _usable(second_try):
        return "RAG (FLAN-T5)", snippets, second_try

    general = _ask(GK_FALLBACK_TEMPLATE) or "No answer produced."
    return "RAG (FLAN-T5) + General Knowledge Fallback", snippets, general


In [None]:
# --- Smarter Router: numeric vs. reviews/sentiment vs. dictionary vs. RAG ---
import re

# Strong calc cues (not generic words like "sales")
# Cue words that send a query to the calculator. _looks_numeric() tests them with `in`
# on the lower-cased query, so each one also matches inside longer words.
CALC_STRONG = {
    "calculate","sum","total","combined","add","plus",
    "average","avg","mean","median",
    "difference","delta","gap","increase","decrease","growth","change",
    "percent","percentage","%","ratio","times","compared","compare",
    "highest","lowest","max","min","peak",
    "ytd","cumulative"
}
# NOTE: this binds the global MONTHS again, as a set; the calculator cell defines it as a
# list in calendar order. After this cell runs, iteration order over MONTHS is unspecified.
MONTHS = {"january","february","march","april","may","june","july",
          "august","september","october","november","december"}

# Reviews / sentiment / pros & cons → RAG
REVIEW_KEYWORDS = {
    "review","reviews","feedback","negative","positive","complaint","complaints",
    "issue","issues","pros","cons","drawbacks","limitations","problems","impression","sentiment"
}

# number, then one or more "operator number" pairs; numbers may contain dots and commas
INLINE_MATH_RE = re.compile(r"(?:\d+[\d\.,]*|\.\d+)(?:\s*[\+\-\*/]\s*(?:\d+[\d\.,]*|\.\d+))+")
# any single digit anywhere in the text
ANY_DIGIT_RE   = re.compile(r"\d")

def _has_inline_math(q: str) -> bool:
    """Tell whether `q`, with all commas removed, contains a match for INLINE_MATH_RE."""
    without_commas = q.replace(",", "")
    return bool(INLINE_MATH_RE.search(without_commas))

def _looks_numeric(q: str) -> bool:
    """Guess whether a query calls for a calculation.

    True as soon as one of these holds for the lower-cased query, checked in this order:
    it contains inline arithmetic, it contains a CALC_STRONG cue, it contains a month
    name, or it contains any digit. Cues and month names are substring tests.
    """
    lowered = q.lower()
    if _has_inline_math(lowered):
        return True
    if any(cue in lowered for cue in CALC_STRONG):
        return True
    if any(month in lowered for month in MONTHS):
        return True
    return ANY_DIGIT_RE.search(lowered) is not None

def _looks_review(q: str) -> bool:
    """Tell whether the lower-cased query contains any REVIEW_KEYWORDS entry as a substring."""
    lowered = q.lower()
    for keyword in REVIEW_KEYWORDS:
        if keyword in lowered:
            return True
    return False

def choose_branch(query: str) -> str:
    """Pick the agent for a query: "dictionary", "calculator" or "rag".

    Rules are applied to the trimmed, lower-cased query; the first match wins:
      1. starts with "define ", or contains "meaning of " or "define:" -> "dictionary"
      2. review/sentiment wording                                       -> "rag"
      3. numeric wording                                                -> "calculator"
      4. anything else                                                  -> "rag"
    """
    q = query.strip().lower()

    wants_definition = q.startswith("define ") or "meaning of " in q or "define:" in q
    if wants_definition:
        return "dictionary"

    # the review check runs before the numeric check on purpose of ordering
    if _looks_review(q):
        return "rag"

    return "calculator" if _looks_numeric(q) else "rag"

# --- Main dispatcher with second-chance override for numeric queries ---
def answer_query(query: str):
    """Route a query to its agent and flatten the result for the UI.

    The branch comes from choose_branch(). A query routed to "rag" that still looks
    numeric is handed to the calculator instead.
    NOTE(review): choose_branch() already returns "calculator" for numeric-looking
    queries unless the review rule fired first, so this override appears to apply only
    to review-style queries — confirm that is intended.

    Returns:
        (branch label, snippets joined by blank lines, answer text). If the chosen agent
        raises, an ("Error", debug text, generic message) triple is returned instead.
    """
    initial = choose_branch(query)

    try:
        if initial == "dictionary":
            agent = dictionary_agent
        elif initial == "calculator" or _looks_numeric(query):
            agent = calculator_agent
        else:
            agent = rag_answer

        branch, snippets, final = agent(query)
        return branch, "\n\n".join(snippets), final

    except Exception as e:
        return "Error", f"(debug) {type(e).__name__}: {e}", "An error occurred while processing your query."


In [None]:
# 9) Gradio UI with timeout
import gradio as gr
import concurrent.futures

TIMEOUT_SECS = 150  # 2.5 minutes

def gradio_fn(user_query, timeout=None, handler=None):
    """Run the query pipeline for the UI and give up after a time limit.

    Args:
        user_query: Text typed into the UI.
        timeout: Seconds to wait for an answer; TIMEOUT_SECS when None.
        handler: Callable taking the query and returning (branch, snippets, final);
            answer_query when None.

    Returns:
        The handler's (branch, snippets, final) triple, or an ("Error", ...) triple when
        the limit is exceeded or the handler raises.
    """
    limit = TIMEOUT_SECS if timeout is None else timeout
    # The executor is not used as a context manager: leaving a `with` block waits for the
    # worker to finish, which would keep the UI blocked long after the timeout fired.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        run = answer_query if handler is None else handler
        future = executor.submit(run, user_query)
        branch, snippets, final = future.result(timeout=limit)
        return branch, snippets, final
    except concurrent.futures.TimeoutError:
        return "Error", "Timed out", f"Answer not found (took more than {limit:g} seconds)."
    except Exception as e:
        return "Error", f"(debug) {type(e).__name__}: {e}", "An error occurred while processing your query."
    finally:
        # Return immediately. A worker that is already running cannot be interrupted and
        # finishes in the background; its result is discarded.
        executor.shutdown(wait=False)

# UI layout: one input box, three output boxes, and a button that calls gradio_fn.
with gr.Blocks(title="RAG-Powered Multi-Agent Q&A") as demo:
    gr.Markdown("## RAG-Powered Multi-Agent Q&A (LangChain)")

    inp = gr.Textbox(label="Ask a question…", placeholder="Type your query here...")
    with gr.Row():
        out_branch = gr.Textbox(label="Which tool/agent branch was used")
    out_snippets = gr.Textbox(label="The retrieved context snippets", lines=10)
    out_final = gr.Textbox(label="The final answer", lines=8)

    btn = gr.Button("Run")
    # The three outputs receive gradio_fn's (branch, snippets, final) tuple in this order.
    btn.click(gradio_fn, inputs=inp, outputs=[out_branch, out_snippets, out_final])

# NOTE(review): share=True presumably publishes the app under a public Gradio URL —
# confirm that exposing it outside this machine is intended.
demo.launch(share=True)
