In [2]:
import os
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import chromadb
from chromadb.config import Settings
from langchain_openai import OpenAIEmbeddings  # updated import
from langchain.text_splitter import RecursiveCharacterTextSplitter
import google.generativeai as genai
import PyPDF2
from docx import Document
from pptx import Presentation

import gradio as gr

# ---------------------- Agent types ----------------------
from enum import Enum
from dataclasses import dataclass

class AgentType(Enum):
    """Closed set of agent role identifiers.

    Each member's value is the lowercase string form of its name, so a role
    can be looked up from its string with ``AgentType("researcher")``.
    """

    # Role names and their string identifiers.
    RESEARCHER = "researcher"
    SYNTHESIZER = "synthesizer"
    FACT_CHECKER = "fact_checker"
    FOLLOW_UP = "follow_up"
    COORDINATOR = "coordinator"

@dataclass
class AgentResponse:
    """Record describing what a single agent produced.

    This is a plain dataclass: no field is validated or normalized on
    construction. Only ``follow_up_needed`` has a default.
    """

    # Role of the agent that produced this response.
    agent_type: AgentType
    # Text of the response.
    content: str
    # Confidence score (range is not enforced here - TODO confirm convention).
    confidence: float
    # Source labels backing the response.
    sources: List[str]
    # Free-text explanation accompanying the response.
    reasoning: str
    # Flag for requesting a follow-up; off unless set explicitly.
    follow_up_needed: bool = False

# ---------------------- Vector DB wrapper ----------------------
class VectorDB:
    """Interface that every vector-store backend is expected to implement.

    The base class only stores a name and a metadata dict; the three
    operations raise ``NotImplementedError`` until a subclass overrides them.
    """

    def __init__(self, name: str, metadata: dict = None):
        # A missing or empty metadata argument is replaced by a fresh dict.
        self.metadata = metadata if metadata else {}
        self.name = name

    def add_documents(self, documents: List[str], metadatas: List[dict], ids: List[str], embeddings: List[List[float]]):
        """Store documents with their metadata, ids and embeddings."""
        raise NotImplementedError

    def query(self, embedding: List[float], n_results: int = 5):
        """Look up the ``n_results`` nearest entries for ``embedding``."""
        raise NotImplementedError

    def get_stats(self) -> dict:
        """Return a dict of summary statistics about the store."""
        raise NotImplementedError

# ---------------------- Chroma Vector DB ----------------------
class ChromaVectorDB(VectorDB):
    """VectorDB backend backed by a persistent, on-disk Chroma collection."""

    # Keys of the per-query result lists returned by ``query``.
    _RESULT_KEYS = ("documents", "metadatas", "ids", "distances")

    def __init__(self, name: str, persist_path: str, metadata: dict = None, anonymized_telemetry: bool = False):
        """Open (or create) the collection ``name`` under ``persist_path``.

        Args:
            name: Collection name; also used as the DB name.
            persist_path: Directory for Chroma's files; created if missing.
            metadata: Optional collection metadata.
            anonymized_telemetry: Forwarded to Chroma's ``Settings``.
        """
        super().__init__(name, metadata)
        os.makedirs(persist_path, exist_ok=True)
        self.persist_path = persist_path
        self.client = chromadb.PersistentClient(path=persist_path, settings=Settings(anonymized_telemetry=anonymized_telemetry))
        # Pass None rather than an empty dict: Chroma versions that validate
        # collection metadata reject a dict with no entries.
        self.collection = self.client.get_or_create_collection(name=name, metadata=metadata or None)

    def add_documents(self, documents: List[str], metadatas: List[dict], ids: List[str], embeddings: List[List[float]]):
        """Add pre-embedded documents to the collection.

        An empty batch is a no-op, because Chroma raises on empty id lists.
        """
        if not ids:
            return
        self.collection.add(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )

    @staticmethod
    def _first_batch(resp: dict, key: str):
        """Return the result list for the first (only) query embedding.

        Chroma nests results one level per query embedding. A missing key,
        a ``None`` value or an empty outer list all yield ``[]`` so callers
        can always iterate the result.
        """
        value = resp.get(key)
        if value is None:
            return []
        if isinstance(value, list):
            return (value[0] or []) if value else []
        return value

    def query(self, embedding: List[float], n_results: int = 5):
        """Return the nearest stored entries for a single query embedding.

        Returns:
            Dict with flat lists under "documents", "metadatas", "ids" and
            "distances" (all empty when the collection holds nothing).
        """
        available = self.collection.count()
        if available == 0:
            # Nothing to search; skip the query so an empty collection never
            # errors regardless of Chroma version.
            return {key: [] for key in self._RESULT_KEYS}
        # Never ask for more neighbours than exist; some Chroma versions
        # raise or warn when n_results exceeds the collection size.
        resp = self.collection.query(query_embeddings=[embedding], n_results=min(n_results, available))
        return {key: self._first_batch(resp, key) for key in self._RESULT_KEYS}

    def get_stats(self) -> dict:
        """Return name, persist path and entry count (``None`` if unknown)."""
        try:
            count = self.collection.count()
        except Exception:
            # Best-effort: stats must not fail just because the count is
            # unavailable, so report it as unknown.
            count = None
        return {"name": self.name, "persist_path": self.persist_path, "count": count}

# ---------------------- VectorDB Registry ----------------------
class VectorDBRegistry:
    """Registry that routes queries to vector DBs by topical similarity.

    Each registered DB gets a "topic vector": the normalized mean of the
    embeddings of its domain keywords. A query is routed to the DBs whose
    topic vector has the highest dot product with the normalized query
    embedding.
    """

    def __init__(self, embeddings: "OpenAIEmbeddings"):
        """Create an empty registry using ``embeddings`` for all lookups."""
        self.embeddings = embeddings
        self.dbs: Dict[str, "VectorDB"] = {}
        self.db_topic_vectors: Dict[str, np.ndarray] = {}
        self.db_topic_keywords: Dict[str, List[str]] = {}

    @staticmethod
    def _unit(vec: np.ndarray) -> np.ndarray:
        """Scale ``vec`` to unit length; a zero vector is returned as-is."""
        norm = float(np.linalg.norm(vec))
        # Dividing by a zero norm would fill the vector with NaN and make
        # every similarity NaN, which breaks sorting.
        return vec / norm if norm > 0.0 else vec

    def register_db(self, db: "VectorDB", domain_keywords: List[str]):
        """Register ``db`` and derive its topic vector from the keywords.

        Args:
            db: Backend to register; ``db.name`` must be unique.
            domain_keywords: Non-empty list of words describing the DB.

        Raises:
            ValueError: If the name is already registered or no keywords
                are given.
        """
        if db.name in self.dbs:
            raise ValueError(f"DB name '{db.name}' already registered.")
        keywords = list(domain_keywords)
        if not keywords:
            raise ValueError(f"DB '{db.name}' needs at least one domain keyword.")

        # Embed all keywords in one batch call, and do it BEFORE touching the
        # registry: if embedding fails, no half-registered DB is left behind.
        kw_embeddings = np.asarray(self.embeddings.embed_documents(keywords), dtype=np.float32)
        topic_vec = self._unit(kw_embeddings.mean(axis=0))

        self.dbs[db.name] = db
        self.db_topic_keywords[db.name] = keywords
        self.db_topic_vectors[db.name] = topic_vec

    def choose_dbs(self, query: str, top_k: int = 1) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(db_name, similarity)`` pairs, best first.

        Returns an empty list when nothing is registered or ``top_k`` is not
        positive; the query is not embedded in that case.
        """
        if top_k <= 0 or not self.db_topic_vectors:
            return []
        q_vec = self._unit(np.asarray(self.embeddings.embed_query(query), dtype=np.float32))
        scored = [
            (name, float(np.dot(q_vec, topic_vec)))
            for name, topic_vec in self.db_topic_vectors.items()
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:top_k]

    def get_db(self, name: str) -> "VectorDB":
        """Return the DB registered as ``name``.

        Raises:
            KeyError: If no DB with that name is registered.
        """
        try:
            return self.dbs[name]
        except KeyError:
            raise KeyError(f"No DB registered under '{name}'. Registered: {sorted(self.dbs)}") from None

# ---------------------- Agentic RAG ----------------------
class AgenticRAGMultiDB:
    """RAG front end that ingests files into, and retrieves from, several DBs.

    Documents are split into chunks, embedded with OpenAI embeddings and
    stored in one of the registered vector DBs; queries are routed to the
    most similar DB(s) through the registry.
    """

    def __init__(self, gemini_api_key: str, openai_api_key: str):
        """Configure the Gemini model, the embedder, registry and splitter."""
        genai.configure(api_key=gemini_api_key)
        self.gemini_model = genai.GenerativeModel('gemini-1.5-flash')
        self.embeddings = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=openai_api_key)
        self.registry = VectorDBRegistry(self.embeddings)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=150)
        self.session_memory: Dict[str, List[Dict[str, Any]]] = {}

    def register_chroma_db(self, name: str, persist_path: str, domain_keywords: List[str]):
        """Create a Chroma-backed DB and register it under ``name``.

        Returns:
            A human-readable confirmation string.
        """
        metadata = {"created_by": "AgenticRAG", "db_name": name}
        db = ChromaVectorDB(name=name, persist_path=persist_path, metadata=metadata)
        self.registry.register_db(db, domain_keywords)
        return f"✅ Registered Chroma DB '{name}' with keywords: {domain_keywords}"

    def _extract_text(self, file_path: str, ext: str) -> Optional[str]:
        """Read the text of a pdf/docx/txt file; ``None`` if ``ext`` is unsupported."""
        if ext == "pdf":
            with open(file_path, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                # extract_text() may yield None for image-only pages.
                return "\n".join(p.extract_text() or "" for p in reader.pages)
        if ext == "docx":
            return "\n".join(p.text for p in Document(file_path).paragraphs)
        if ext == "txt":
            with open(file_path, "r", encoding="utf-8") as f:
                return f.read()
        return None

    def ingest_to_db(self, file_path: str, file_name: str, target_db_name: Optional[str] = None, session_id: str = "default"):
        """Extract, chunk, embed and store one file.

        Args:
            file_path: Location of the file on disk.
            file_name: Display name; its extension selects the extractor.
            target_db_name: DB to store into; when omitted the registry picks
                the DB most similar to the file name plus the first 200
                characters of text.
            session_id: Stored in every chunk's metadata.

        Returns:
            A status string starting with "✅ Ingested" on success, or with
            "❌" describing why nothing was stored.
        """
        ext = file_name.lower().split('.')[-1]
        try:
            content = self._extract_text(file_path, ext)
        except Exception as e:
            return f"❌ Error extracting {file_name}: {e}"
        if content is None:
            return f"❌ Unsupported extension: {ext}"

        # Drop blank chunks: a scanned PDF or empty file yields no text, and
        # the vector store rejects an empty batch.
        chunks = [c for c in self.text_splitter.split_text(content) if c.strip()]
        if not chunks:
            return f"❌ No extractable text found in {file_name}; nothing was stored."

        if not target_db_name:
            if not self.registry.dbs:
                return "❌ No vector DB is registered; register one before ingesting."
            chosen = self.registry.choose_dbs(f"{file_name} {content[:200]}", top_k=1)
            target_db_name = chosen[0][0] if chosen else next(iter(self.registry.dbs))

        try:
            db = self.registry.get_db(target_db_name)
        except KeyError:
            return f"❌ Unknown target DB '{target_db_name}'."

        # One timestamp per ingest keeps all ids of a run visibly related;
        # the chunk index keeps them unique.
        stamp = datetime.now().timestamp()
        ids = [f"{file_name}_{i}_{stamp}" for i in range(len(chunks))]
        metadatas = [
            {"file_name": file_name, "file_type": ext, "chunk_index": i, "session_id": session_id}
            for i in range(len(chunks))
        ]
        try:
            # Batch call: one embedding request instead of one per chunk.
            vectors = self.embeddings.embed_documents(chunks)
            db.add_documents(chunks, metadatas, ids, vectors)
        except Exception as e:
            return f"❌ Error indexing {file_name}: {e}"
        return f"✅ Ingested {file_name} into DB '{target_db_name}' ({len(chunks)} chunks)."

    def retrieve_for_query(self, query: str, top_dbs: int = 1, n_results_per_db: int = 3):
        """Fetch the best-matching chunks for ``query`` from the routed DBs.

        Returns:
            ``(documents, sources)`` where each source is
            ``"<db_name>:<file_name>"`` (``unknown`` if the chunk has no
            file name in its metadata). Both lists are empty when no DB
            could be chosen.
        """
        docs_collected, sources = [], []
        chosen = self.registry.choose_dbs(query, top_k=top_dbs)
        if not chosen:
            return docs_collected, sources

        q_emb = self.embeddings.embed_query(query)
        for db_name, _sim in chosen:
            results = self.registry.get_db(db_name).query(q_emb, n_results=n_results_per_db)
            # Backends may report a field as None instead of an empty list.
            retrieved_docs = results.get("documents") or []
            retrieved_meta = results.get("metadatas") or []
            for i, d in enumerate(retrieved_docs):
                meta = (retrieved_meta[i] if i < len(retrieved_meta) else None) or {}
                docs_collected.append(d)
                sources.append(f"{db_name}:{meta.get('file_name', 'unknown')}")
        return docs_collected, sources

# ---------------------- Gradio App ----------------------
# Module-level state shared by the Gradio callback functions.
rag_instance = None  # assigned by init_rag(); None until it has been called
uploaded_files: List[str] = []  # replaced by upload_document(); read by ingest_documents()

def init_rag(gemini_key, openai_key):
    """Build the global RAG instance and register the two example DBs.

    The global ``rag_instance`` is only replaced once construction and both
    registrations have succeeded, so a failed attempt never leaves a
    half-configured instance behind.

    Args:
        gemini_key: Gemini API key as typed by the user.
        openai_key: OpenAI API key as typed by the user.

    Returns:
        A status string: "✅ ..." on success, "❌ ..." with the reason
        otherwise.
    """
    global rag_instance, uploaded_files
    # Pasted keys often carry stray whitespace; empty textboxes may be None.
    gemini_key = (gemini_key or "").strip()
    openai_key = (openai_key or "").strip()
    if not gemini_key or not openai_key:
        return "❌ Both a Gemini API key and an OpenAI API key are required."

    try:
        rag = AgenticRAGMultiDB(gemini_key, openai_key)
        # Example DBs
        rag.register_chroma_db("legal_db", "./chroma_legal", ["contracts", "agreements", "law"])
        rag.register_chroma_db("finance_db", "./chroma_finance", ["finance", "revenue", "profit"])
    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # surfacing a raw traceback.
        return f"❌ Initialization failed: {e}"

    rag_instance = rag
    uploaded_files = []
    return "✅ Agentic RAG initialized with 2 DBs."

def upload_document(files):
    """Remember the uploaded files for a later ingest and confirm each one.

    Args:
        files: List of uploaded files as delivered by the upload widget;
            entries may be path strings or file objects exposing ``.name``.
            ``None`` or an empty list means nothing was selected.

    Returns:
        One "✅ Uploaded <name>" line per file, or a "❌" notice when no
        file was provided.
    """
    global uploaded_files
    if not files:
        uploaded_files = []
        return "❌ No files selected."
    # Normalize to plain path strings so later steps can treat every entry
    # the same way regardless of how the widget delivered it.
    paths = [str(getattr(f, "name", f)) for f in files]
    uploaded_files = paths
    return "\n".join(f"✅ Uploaded {os.path.basename(p)}" for p in paths)

import pandas as pd

# Ingest every uploaded file and report the per-file outcome as a DataFrame
def ingest_documents():
    """Ingest all uploaded files and summarize the outcome per file.

    Returns:
        A DataFrame with the columns "File Name", "Target DB", "Chunks" and
        "Status" - one row per uploaded file. "Target DB" is "Error" when
        ingestion failed and "Unknown" when a success message could not be
        parsed. The columns are present even when there are no rows.
    """
    columns = ["File Name", "Target DB", "Chunks", "Status"]
    if rag_instance is None:
        not_ready = {
            "File Name": "",
            "Target DB": "Error",
            "Chunks": 0,
            "Status": "❌ Agentic RAG is not initialized. Initialize it before ingesting.",
        }
        return pd.DataFrame([not_ready], columns=columns)

    ingestion_results = []
    for entry in uploaded_files or []:
        # Entries are normally path strings; tolerate file objects too.
        file_path = getattr(entry, "name", entry)
        file_name = os.path.basename(file_path)
        try:
            msg = rag_instance.ingest_to_db(file_path, file_name)
        except Exception as e:
            # One bad file must not abort the rest of the batch.
            msg = f"❌ Error ingesting {file_name}: {e}"

        db_name, chunk_count = "Error", 0
        if msg.startswith("✅") and "Ingested" in msg:
            # Parse from the RIGHT: the file name comes first in the message
            # and may itself contain "(" or quotes (e.g. "report (1).pdf").
            try:
                db_name = msg.rsplit("into DB '", 1)[1].rsplit("'", 1)[0]
                chunk_count = int(msg.rsplit("(", 1)[1].split()[0])
            except (IndexError, ValueError):
                db_name, chunk_count = "Unknown", 0

        ingestion_results.append({
            "File Name": file_name,
            "Target DB": db_name,
            "Chunks": chunk_count,
            "Status": msg
        })
    return pd.DataFrame(ingestion_results, columns=columns)


def ask_query(query):
    """Answer ``query`` from the retrieved document chunks using Gemini.

    Args:
        query: The user's question.

    Returns:
        The answer text followed by the names of the DBs the context came
        from, "No documents found." when retrieval returned nothing, or a
        "❌ ..." message when the system is not ready, the question is empty
        or the model call failed.
    """
    if rag_instance is None:
        return "❌ Agentic RAG is not initialized. Initialize it before asking a question."
    if not query or not query.strip():
        return "❌ Please enter a question."

    docs, sources = rag_instance.retrieve_for_query(query)

    if not docs:
        return "No documents found."

    # Concatenate chunks to feed to Gemini LLM
    context = "\n\n".join(docs[:5])  # limit to top 5 chunks to avoid token limits
    prompt = f"""
    You are an AI assistant. Use the following document excerpts to answer the user's question.
    Question: {query}
    Documents:
    {context}

    Answer concisely, using only the relevant information from the documents.
    """

    try:
        response = rag_instance.gemini_model.generate_content(prompt)
    except Exception as e:
        # UI boundary: show the failure in the answer box.
        return f"❌ Gemini request failed: {e}"

    # The `.text` accessor raises ValueError when the response carries no
    # usable text (e.g. it was blocked), so a plain getattr default is not
    # enough to fall back safely.
    try:
        answer_text = response.text
    except (ValueError, AttributeError):
        answer_text = None
    answer_text = answer_text or "No answer generated."

    # Sorted so the listing is stable between runs.
    dbs_used = sorted({s.split(":")[0] for s in sources})
    return f"Answer (excerpt):\n{answer_text}\n\nDB(s) used: {', '.join(dbs_used)}"


# Build the Gradio UI: four rows, each wiring one button to one of the
# callback functions defined above. Statement order determines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Agentic RAG Multi-DB Demo")
    # Row 1: API keys -> init_rag, status shown in a textbox.
    with gr.Row():
        gemini_key = gr.Textbox(label="Gemini API Key", type="password")
        openai_key = gr.Textbox(label="OpenAI API Key", type="password")
        btn_init = gr.Button("Initialize Agentic RAG")
        output_init = gr.Textbox(label="Initialization Status")
        btn_init.click(init_rag, inputs=[gemini_key, openai_key], outputs=output_init)
    # Row 2: file picker (pdf/docx/txt only) -> upload_document.
    with gr.Row():
        file_input = gr.Files(file_types=[".pdf", ".docx", ".txt"], label="Upload Documents")
        btn_upload = gr.Button("Upload Documents")
        output_upload = gr.Textbox(label="Upload Status")
        btn_upload.click(upload_document, inputs=file_input, outputs=output_upload)
    # Row 3: ingest the previously uploaded files; ingest_documents takes no
    # inputs and its DataFrame result fills the table.
    with gr.Row():
        btn_ingest = gr.Button("Ingest Uploaded Documents")
        output_ingest = gr.DataFrame(headers=["File Name", "Target DB", "Chunks", "Status"])
        btn_ingest.click(ingest_documents, inputs=[], outputs=output_ingest)

    # Row 4: free-text question -> ask_query, answer shown in a textbox.
    with gr.Row():
        query_input = gr.Textbox(label="Ask a question")
        btn_query = gr.Button("Query")
        output_query = gr.Textbox(label="Answer & DB used")
        btn_query.click(ask_query, inputs=query_input, outputs=output_query)

# Start the web server as soon as this code runs (module-level side effect).
demo.launch()


* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.


