# MasterHive Project

This project develops a basic RAG (retrieval-augmented generation) system that prototypes a custom "chat with your documents" business model.

To achieve this, the system is developed in phases:
- Chunking and indexing
- Embedding and storage
- Retrieval and reranking
- Storage and chat-history management


In [None]:
%pip install langchain langgraph langchain-google-genai langchain_community chromadb rank-bm25 transformers pypdf nest-asyncio pymongo

## A Simple RAG

In [None]:
from langchain import PromptTemplate
from langchain import hub
from langchain.docstore.document import Document
from langchain.document_loaders import WebBaseLoader
from langchain.schema import StrOutputParser
from langchain.schema.prompt_template import format_document
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Chroma

In [None]:
import os

# Only fall back to the placeholder when no key is exported yet. Assigning it
# unconditionally would clobber a real key that is already in the environment,
# making every Gemini call fail authentication. Replace the placeholder (or export
# GOOGLE_API_KEY before starting the notebook) to authenticate.
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = "GOOGLE_API_KEY"

In [None]:
# Read and parse the website data.
# WebBaseLoader fetches the page over the network; the next cell reads docs[0].
loader = WebBaseLoader("https://blog.google/technology/ai/google-gemini-ai/")
docs = loader.load()

In [None]:
# Extract the text from the website data document
text_content = docs[0].page_content

# The text content between the substrings "code, audio, image and video." to
# "Cloud TPU v5p" is relevant for this. You can use Python's `split()`
# to select the required content.
# NOTE: `split(marker, 1)[1]` raises IndexError if the first marker is missing
# (e.g. the blog post was edited). If "Cloud TPU v5p" is missing, the second
# split simply keeps the whole remaining text.
text_content_1 = text_content.split("code, audio, image and video.",1)[1]
final_text = text_content_1.split("Cloud TPU v5p",1)[0]

# Convert the text to LangChain's `Document` format (a single document).
docs = [Document(page_content=final_text, metadata={"source": "local"})]

In [None]:
# Initialize Gemini's embedding model.
# Used both to embed the documents when indexing and to embed queries at search time.
from langchain_google_genai import GoogleGenerativeAIEmbeddings

gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

In [None]:
# Save to disk: Store the data using Chroma.
# NOTE(review): no collection_name/ids are passed, so this writes to the default
# collection in ./chroma_db; re-running the cell presumably adds another copy of
# the same text — verify before relying on "only 1 document is stored" below.
vectorstore = Chroma.from_documents(
                     documents=docs,                 # Data
                     embedding=gemini_embeddings,    # Embedding model
                     persist_directory="./chroma_db" # Directory to save data
                     )

In [None]:
# Create a retriever using Chroma
# Load from disk (re-opens the store persisted by the previous cell)
vectorstore_disk = Chroma(
                        persist_directory="./chroma_db",     # Directory of db
                        embedding_function=gemini_embeddings  # Embedding model
                   )
# Get the Retriever interface for the store to use later.
# When an unstructured query is given to a retriever it will return documents.
#
# Since only 1 document is stored in the Chroma vector store, search_kwargs `k`
# is set to 1 to decrease the `k` value of chroma's similarity search from 4 to
# 1. If you don't pass this value, you will get a warning.
retriever = vectorstore_disk.as_retriever(search_kwargs={"k": 1})

# Check if the retriever is working by trying to fetch the relevant docs related
# to the word 'MMLU' (Massive Multitask Language Understanding). 
# If the length is greater than zero, it means that the retriever is functioning well.
# NOTE: newer LangChain releases deprecate get_relevant_documents() in favour of
# retriever.invoke(query).
print(len(retriever.get_relevant_documents("MMLU")))

In [None]:
# Initialize Gemini
from langchain_google_genai import ChatGoogleGenerativeAI

# To configure model parameters use the `generation_config` parameter.
# eg. generation_config = {"temperature": 0.7, "topP": 0.8, "topK": 40}
# If you only want to set a custom temperature for the model use the
# "temperature" parameter directly.
# NOTE(review): confirm the installed langchain-google-genai accepts
# `generation_config`; later cells pass `temperature=` directly.

# No sampling parameters are passed here, so the library defaults apply.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")

In [None]:
# Prompt template to query Gemini.
# {question} and {context} are the template variables; the chain below fills
# them in. Printing the PromptTemplate shows the variables it detected.
llm_prompt_template = """You are an assistant for question-answering tasks.
Use the following context to answer the question.
If you don't know the answer, just say that you don't know.
Use five sentences maximum and keep the answer concise.\n
Question: {question} \nContext: {context} \nAnswer:"""

llm_prompt = PromptTemplate.from_template(llm_prompt_template)

print(llm_prompt)

In [None]:
def format_docs(docs):
    """Return the ``page_content`` of every document, separated by one blank line.

    This turns retriever output into the single ``context`` string the prompt expects.
    """
    separator = "\n\n"
    texts = [document.page_content for document in docs]
    return separator.join(texts)

# Create stuff documents chain using LCEL.
#
# This is called a chain because you are chaining together different elements
# with the LLM. In the following example, to create the stuff chain, you will
# combine the relevant context from the website data matching the question, the
# LLM model, and the output parser together like a chain using LCEL.
#
# The chain implements the following pipeline:
# 1. Extract the website data relevant to the question from the Chroma
#    vector store and save it to the variable `context`.
# 2. `RunnablePassthrough` option to provide `question` when invoking
#    the chain.
# 3. The `context` and `question` are then passed to the prompt where they
#    are populated in the respective variables.
# 4. This prompt is then passed to the LLM (`gemini-2.0-flash`).
# 5. Output from the LLM is passed through an output parser
#    to structure the model's response.
# `format_docs` follows the retriever so the prompt receives one string,
# not a list of Documents.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | llm_prompt
    | llm
    | StrOutputParser()
)

In [None]:
# Run the chain end to end; the question string is passed through as {question}.
rag_chain.invoke("What is Gemini?")

In [None]:
# Putting it all together into a simple RAG

from langchain import PromptTemplate
from langchain import hub
from langchain.docstore.document import Document
from langchain.document_loaders import WebBaseLoader
from langchain.schema import StrOutputParser
from langchain.schema.prompt_template import format_document
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Chroma

import os

# Keep a key that is already exported; only fall back to the placeholder when
# none is set, so a real key is never overwritten by "API_Key".
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = "API_Key"

# Read and parse the website data (network fetch)
loader = WebBaseLoader("https://blog.google/technology/ai/google-gemini-ai/")
docs = loader.load()

# Extract the text from the website data document
text_content = docs[0].page_content

# The text content between the substrings "code, audio, image and video." to
# "Cloud TPU v5p" is relevant for this tutorial. You can use Python's `split()`
# to select the required content.
# NOTE: `[1]` raises IndexError if the first marker no longer appears on the page.
text_content_1 = text_content.split("code, audio, image and video.",1)[1]
final_text = text_content_1.split("Cloud TPU v5p",1)[0]

# Convert the text to LangChain's `Document` format
docs = [Document(page_content=final_text, metadata={"source": "local"})]

# Initialize Gemini's embedding model
from langchain_google_genai import GoogleGenerativeAIEmbeddings
gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

# Save to disk: Store the data using Chroma
vectorstore = Chroma.from_documents(
                     documents=docs,                 # Data
                     embedding=gemini_embeddings,    # Embedding model
                     persist_directory="./chroma_db" # Directory to save data
                     )

# Create a retriever using Chroma
# Load from disk
vectorstore_disk = Chroma(
                        persist_directory="./chroma_db",       # Directory of db
                        embedding_function=gemini_embeddings   # Embedding model
                   )

# k=1: return only the single best-matching chunk.
retriever = vectorstore_disk.as_retriever(search_kwargs={"k": 1})

# Sanity check: a non-zero count for the query "MMLU" means
# the retriever is functioning well.
print(len(retriever.get_relevant_documents("MMLU")))

# Initialize Gemini
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")


# Prompt template to query Gemini ({question} and {context} are filled by the chain)
llm_prompt_template = """You are an assistant for question-answering tasks.
Use the following context to answer the question.
If you don't know the answer, just say that you don't know.
Use five sentences maximum and keep the answer concise.\n
Question: {question} \nContext: {context} \nAnswer:"""

llm_prompt = PromptTemplate.from_template(llm_prompt_template)
print(llm_prompt)

def format_docs(docs):
    """Concatenate each document's ``page_content`` with a blank line between them."""
    pieces = [item.page_content for item in docs]
    return "\n\n".join(pieces)

# LCEL pipeline: {retrieved context, passthrough question} -> prompt -> Gemini -> plain string.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | llm_prompt
    | llm
    | StrOutputParser()
)

rag_chain.invoke("What is Gemini?")

## Advanced RAG

### Introducing LangGraph and Reranking

In [None]:
import os
from langchain import PromptTemplate, hub
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document, StrOutputParser
from langchain.vectorstores import Chroma
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from typing import Dict, List, TypedDict
import requests
from rank_bm25 import BM25Okapi  # For hybrid retrieval
from transformers import AutoTokenizer, AutoModelForSequenceClassification  # For reranking

# Respect a key already present in the environment instead of overwriting it
# with the placeholder (which would make every Gemini call fail to authenticate).
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = "GOOGLE_API_KEY"


# 1. Improved Document Processing --------------------------------
def load_and_chunk_data(source: str, is_pdf: bool = False):
    """Load ``source`` and split it into overlapping text chunks.

    Args:
        source: A PDF file path when ``is_pdf`` is true, otherwise a web URL.
        is_pdf: Select ``PyPDFLoader`` instead of ``WebBaseLoader``.

    Returns:
        The chunk ``Document`` list produced by ``load_and_split``.
    """
    loader_cls = PyPDFLoader if is_pdf else WebBaseLoader
    document_loader = loader_cls(source)

    # 1000-character chunks with 200 characters of overlap, so text that straddles
    # a chunk boundary still appears whole in at least one chunk.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        separators=["\n\n", "\n", " ", ""],
    )
    return document_loader.load_and_split(splitter)

# Process multiple documents (PDF example)
docs = load_and_chunk_data("/content/Rag_data/Talent Relocation Handbook.pdf", is_pdf=True)

# 2. Initialize Chroma Vector Store ------------------------------
# The simple-RAG cells above persist the Gemini blog post into the *default*
# collection of the same "./chroma_db" directory, and Chroma.from_documents adds
# new entries (fresh ids) every time it runs. HybridRetriever below loads every
# stored chunk into BM25, so sharing that collection mixes in unrelated text and
# duplicates. Use a dedicated collection and clear it before re-indexing.
COLLECTION_NAME = "talent_handbook"
gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
Chroma(
    collection_name=COLLECTION_NAME,
    persist_directory="./chroma_db",
    embedding_function=gemini_embeddings,
).delete_collection()
vectorstore = Chroma.from_documents(
    documents=docs,
    embedding=gemini_embeddings,
    persist_directory="./chroma_db",
    collection_name=COLLECTION_NAME,
)

# Load the persisted vector store from disk (same dedicated collection)
vectorstore_disk = Chroma(
    persist_directory="./chroma_db",
    embedding_function=gemini_embeddings,
    collection_name=COLLECTION_NAME,
)

# 3. Hybrid Retrieval System --------------------------------------
class HybridRetriever:
    """Combine dense (Chroma similarity) and sparse (BM25) retrieval over one store.

    BM25 needs the raw corpus, so every chunk persisted in ``vector_store`` is
    read once at construction time and indexed in memory.
    """

    def __init__(self, vector_store, text_field="page_content"):
        """
        Args:
            vector_store: A LangChain ``Chroma`` store; its ``get`` and
                ``similarity_search`` methods are used.
            text_field: Kept for backward compatibility only. ``BM25Retriever``
                always indexes ``Document.page_content``; it never used this value.
        """
        self.vector_store = vector_store
        self.text_field = text_field

        # Read the stored chunks directly. similarity_search("", k=1000) embedded an
        # empty query (an extra API call that may be rejected) and silently dropped
        # every chunk beyond the first 1000.
        stored = vector_store.get(include=["documents", "metadatas"])
        texts = stored.get("documents") or []
        metadatas = stored.get("metadatas") or [None] * len(texts)
        self.documents = [
            Document(page_content=text, metadata=metadata or {})
            for text, metadata in zip(texts, metadatas)
        ]
        print(f"Loaded {len(self.documents)} documents from the vector store.")

        # BM25 cannot be built over an empty corpus; fall back to dense-only search.
        self.bm25_retriever = None
        if self.documents:
            self.bm25_retriever = BM25Retriever.from_documents(self.documents)
            self.bm25_retriever.k = 5

    def retrieve(self, query: str):
        """Return vector hits followed by BM25 hits, keeping the first copy of each text."""
        vector_results = self.vector_store.similarity_search(query, k=5)
        bm25_docs = (
            self.bm25_retriever.get_relevant_documents(query)
            if self.bm25_retriever is not None
            else []
        )
        unique, seen = [], set()
        for doc in vector_results + bm25_docs:
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                unique.append(doc)
        return unique


# 3. Reranking System ---------------------------------------------
class Reranker:
    """Cross-encoder reranker: scores each (query, passage) pair jointly."""

    def __init__(self, model_name="BAAI/bge-reranker-base"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

    def rerank(self, query: str, documents: List[Document], top_n: int = 3):
        """Return up to ``top_n`` documents ordered by descending relevance to ``query``.

        Args:
            query: The user question.
            documents: Candidate documents; only ``page_content`` is scored.
            top_n: Maximum number of documents to return.

        Returns:
            A new list with the highest-scoring documents first (``[]`` for no input).
        """
        import torch  # backend of the transformers model; local import keeps cell imports unchanged

        if not documents:
            return []
        pairs = [(query, doc.page_content) for doc in documents]
        features = self.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():  # inference only: no autograd graph needed
            logits = self.model(**features).logits
        # Sequence-classification heads return logits of shape (N, num_labels).
        # Sorting that 2-D tensor directly sorts each length-1 row, so every index
        # is 0 and the first document is returned top_n times. Reduce to one score
        # per document first. bge rerankers emit a single relevance logit; for
        # multi-label heads the last column is used (assumed positive class — confirm).
        scores = logits.reshape(len(documents), -1)[:, -1]
        order = torch.argsort(scores, descending=True)
        return [documents[int(i)] for i in order[:top_n]]

# 4. LangGraph Workflow -------------------------------------------
class GraphState(TypedDict):
    """Shared state passed between the LangGraph nodes of this pipeline."""

    question: str  # user question, supplied to app.invoke()
    context: List[Document]  # reranked chunks written by retrieve_nodes
    answer: str  # model response written by generate_answer


# Cache for the expensive retrieval components. HybridRetriever re-reads and
# re-indexes the whole store, and Reranker reloads model weights from disk, so
# building them for every question is wasteful. Re-running this cell resets it.
_retrieval_cache = {}


def _retrieval_components():
    """Build the (HybridRetriever, Reranker) pair on first use and reuse it afterwards."""
    if "components" not in _retrieval_cache:
        _retrieval_cache["components"] = (HybridRetriever(vectorstore_disk), Reranker())
    return _retrieval_cache["components"]


def retrieve_nodes(state: GraphState):
    """LangGraph node: hybrid-retrieve candidates for the question, then rerank them.

    Writes the reranked documents to ``state["context"]`` and returns the state.
    Because components are cached, chunks added to ``vectorstore_disk`` after the
    first query are not seen by BM25 until this cell is re-run.
    """
    hybrid_retriever, reranker = _retrieval_components()
    initial_docs = hybrid_retriever.retrieve(state["question"])
    state["context"] = reranker.rerank(state["question"], initial_docs)
    return state

def generate_answer(state: GraphState):
    """LangGraph node: answer ``state["question"]`` from ``state["context"]`` with Gemini.

    The context chunks are joined with blank lines and inserted into the prompt.
    The model's text is stored in ``state["answer"]`` and the state is returned.
    """
    context_text = "\n\n".join(chunk.page_content for chunk in state["context"])

    template = """
    [Advanced RAG System]
    Context: {context}
    ---
    Question: {question}
    Answer in markdown with sources. If unsure, say "I don't know".
    """
    answer_prompt = PromptTemplate.from_template(template)

    gemini = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
    pipeline = answer_prompt | gemini | StrOutputParser()
    inputs = {"question": state["question"], "context": context_text}
    state["answer"] = pipeline.invoke(inputs)
    return state


# 5. Full Pipeline Execution --------------------------------------
# Linear graph: retrieve -> generate -> END.
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve_nodes)
workflow.add_node("generate", generate_answer)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)


# Run the workflow: only "question" is supplied; the nodes fill "context" and "answer".
app = workflow.compile()
result = app.invoke({"question": "What do I stand to gain with this new company as a talent?"})
print(result["answer"])

### Logging and hybrid keyword + semantic retrieval

Let’s refine and optimize the code further to ensure it’s production-ready and scalable.

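Key improvements to the code (not all of them are implemented in the next cell: batched reranking, query rewriting, configuration management and containerization are still planned):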
1. **Error Handling and Logging:**
  - Add robust error handling for API calls, document loading, and retrieval.
  - Implement logging for debugging and monitoring.

2. **Optimized Hybrid Retrieval:**

  - Ensure the BM25 retriever is initialized correctly with all documents.

  - Improve deduplication logic for combined results.

3. **Reranking Efficiency:**

  - Use a lightweight reranking model for faster inference.

  - Add batching for reranking to handle large document sets.

4. **LangGraph Workflow Enhancements:**

  - Add a fallback mechanism for failed retrievals.

  - Include query rewriting or expansion for better retrieval.

5. **Production Readiness:**

  - Add configuration management (e.g., using pydantic or environment variables).

  - Containerize the application for deployment.

In [None]:
import os
import logging
from typing import List, Dict, TypedDict
from concurrent.futures import ThreadPoolExecutor
from langchain import PromptTemplate
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document, StrOutputParser
from langchain.vectorstores import Chroma
from langchain.retrievers import BM25Retriever
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from rank_bm25 import BM25Okapi
import numpy as np
from sklearn.cluster import KMeans

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the API key only when the environment does not already provide one, so a
# real exported key is not overwritten by the placeholder.
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = "GEMINI_API_KEY"

# 1. Improved Document Processing for Multiple PDFs --------------------------------
def load_and_chunk_pdf(file_path: str) -> List[Document]:
    """Parse one PDF and split it into chunks of up to 1000 characters (200 overlap).

    Any exception raised while loading or splitting is logged and an empty list is
    returned, so a single bad file does not abort a batch.
    """
    splitter_options = {
        "chunk_size": 1000,
        "chunk_overlap": 200,
        "separators": ["\n\n", "\n", " ", ""],
    }
    try:
        pdf_loader = PyPDFLoader(file_path)
        chunker = RecursiveCharacterTextSplitter(**splitter_options)
        chunks = pdf_loader.load_and_split(chunker)
    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
        return []
    return chunks

def load_and_chunk_folder(folder_path: str) -> List[Document]:
    """Load and chunk all PDFs directly inside a folder (not recursive).

    Regular files whose name ends in ``.pdf`` are matched case-insensitively, so
    ``.PDF`` exports are included. They are processed in sorted name order, which
    makes the chunk order reproducible; ``os.listdir`` order is arbitrary. Files
    are parsed concurrently with ``load_and_chunk_pdf``.

    Args:
        folder_path: Directory to scan.

    Returns:
        All chunks, grouped by file in sorted order; ``[]`` (with a warning) when
        the folder contains no PDFs.

    Raises:
        ValueError: If ``folder_path`` is not an existing directory.
    """
    if not os.path.isdir(folder_path):
        raise ValueError(f"{folder_path} is not a valid directory.")

    pdf_files = []
    for name in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, name)
        # Skip directories that happen to be named "*.pdf".
        if name.lower().endswith(".pdf") and os.path.isfile(path):
            pdf_files.append(path)
    if not pdf_files:
        logger.warning(f"No PDF files found in {folder_path}.")
        return []

    all_docs = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(load_and_chunk_pdf, pdf_file) for pdf_file in pdf_files]
        # Collect results in submission order so output follows the sorted file list.
        for future in futures:
            try:
                all_docs.extend(future.result())
            except Exception as e:
                logger.error(f"Error processing PDF: {e}")

    logger.info(f"Loaded and chunked {len(all_docs)} documents from {len(pdf_files)} PDFs.")
    return all_docs

# Process all PDFs in a folder
try:
    folder_path = "/content/Rag_data"  # Replace with your folder path
    docs = load_and_chunk_folder(folder_path)
except Exception as e:
    logger.error(f"Failed to process folder: {e}")
    raise

# 2. Initialize Chroma Vector Store ------------------------------
# Every cell persists into "./chroma_db". Earlier cells store the Gemini blog post
# in the default collection, and Chroma.from_documents adds new entries (fresh ids)
# on each run. Use a dedicated collection and clear it before re-indexing, so the
# store holds exactly the chunks in `docs` (the same corpus BM25 uses).
COLLECTION_NAME = "masterhive_pdfs"
try:
    gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory="./chroma_db",
        embedding_function=gemini_embeddings,
    ).delete_collection()
    vectorstore = Chroma.from_documents(
        documents=docs,
        embedding=gemini_embeddings,
        persist_directory="./chroma_db",
        collection_name=COLLECTION_NAME,
    )
    logger.info("Vector store initialized and persisted.")
except Exception as e:
    logger.error(f"Error initializing vector store: {e}")
    raise

# Load the persisted vector store from disk (same dedicated collection)
try:
    vectorstore_disk = Chroma(
        persist_directory="./chroma_db",
        embedding_function=gemini_embeddings,
        collection_name=COLLECTION_NAME,
    )
    logger.info("Loaded vector store from disk.")
except Exception as e:
    logger.error(f"Error loading vector store from disk: {e}")
    raise

# 3. Advanced Indexing: Hierarchical Indexing --------------------------------
class HierarchicalIndex:
    """Two-level semantic index: KMeans clusters of chunk embeddings -> chunks.

    ``retrieve`` routes a query to the clusters whose centroids are most similar
    to it, then ranks only the chunks in those clusters by cosine similarity.
    """

    def __init__(self, documents: List[Document], embedding_model):
        """
        Args:
            documents: Chunks to index.
            embedding_model: Provides ``embed_documents(texts)`` and either
                ``embed_query(text)`` or ``embed(text)``.
        """
        self.documents = documents
        self.embedding_model = embedding_model
        self.index = self._build_index()

    def _build_index(self):
        """Embed every document, L2-normalise the vectors and cluster them.

        Returns:
            Mapping ``cluster id -> list of indices into self.documents``.
        """
        self._embeddings = np.zeros((0, 0))
        self._centroids = np.zeros((0, 0))
        if not self.documents:
            self.index = {}
            return self.index
        vectors = self.embedding_model.embed_documents([doc.page_content for doc in self.documents])
        self._embeddings = self._normalize(np.asarray(vectors, dtype=float))
        self.index = self._cluster_documents(self._embeddings)
        return self.index

    def _cluster_documents(self, embeddings, n_clusters=10):
        """Run KMeans and record which documents fall into each cluster.

        ``n_clusters`` is capped at the number of documents because KMeans raises
        when asked for more clusters than samples. ``random_state`` makes the
        clustering reproducible.
        """
        n_clusters = max(1, min(n_clusters, len(embeddings)))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
        labels = kmeans.fit_predict(embeddings)
        self._centroids = self._normalize(np.asarray(kmeans.cluster_centers_, dtype=float))
        clusters = {cluster_id: [] for cluster_id in range(n_clusters)}
        for doc_idx, label in enumerate(labels):
            clusters[int(label)].append(doc_idx)
        return clusters

    @staticmethod
    def _normalize(matrix):
        """Scale vectors along the last axis to unit L2 norm (zero vectors unchanged)."""
        norms = np.linalg.norm(matrix, axis=-1, keepdims=True)
        return matrix / np.where(norms == 0, 1.0, norms)

    def retrieve(self, query: str, top_k: int = 5, n_probe: int = 2) -> List[Document]:
        """Return up to ``top_k`` documents most similar to ``query``.

        Args:
            query: Free-text query.
            top_k: Maximum number of documents to return.
            n_probe: Number of nearest clusters to search. If those clusters hold
                fewer than ``top_k`` documents, all documents are ranked instead.
        """
        if not self.documents or top_k <= 0:
            return []
        embed = getattr(self.embedding_model, "embed_query", None) or self.embedding_model.embed
        query_vec = self._normalize(np.asarray(embed(query), dtype=float))
        # Route the query to its closest clusters by centroid cosine similarity.
        nearest = np.argsort(-(self._centroids @ query_vec), kind="stable")[:max(1, n_probe)]
        candidates = [i for c in nearest for i in self.index.get(int(c), [])]
        if len(candidates) < top_k:
            candidates = list(range(len(self.documents)))
        scores = self._embeddings[candidates] @ query_vec
        best = np.argsort(-scores, kind="stable")[:top_k]
        return [self.documents[candidates[i]] for i in best]

# 4. Contextual Embeddings --------------------------------
class ContextualEmbeddings:
    """Thin wrapper around Gemini embeddings with query and document helpers."""

    def __init__(self, model_name="models/embedding-001"):
        self.model = GoogleGenerativeAIEmbeddings(model=model_name)

    def embed(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.model.embed_query(text)

    def embed_query(self, text: str) -> List[float]:
        """Alias of ``embed`` matching the LangChain ``Embeddings`` method name.

        Code written against LangChain embedding models (e.g.
        ``HierarchicalIndex.retrieve``) calls ``embed_query``; without this alias
        it raises ``AttributeError`` on this wrapper.
        """
        return self.embed(text)

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a batch of document texts."""
        return self.model.embed_documents(documents)

# 5. Hybrid Retrieval System --------------------------------
class HybridRetriever:
    """Merge dense (vector) and sparse (BM25) results, dropping duplicate texts."""

    def __init__(self, vector_store, bm25_retriever):
        self.vector_store = vector_store
        self.bm25_retriever = bm25_retriever

    def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
        """Return vector hits first, then BM25 hits, keeping the first copy of each text.

        ``top_k`` bounds only the vector search; the BM25 side returns as many
        documents as its own ``k`` setting allows.
        """
        dense_hits = self.vector_store.similarity_search(query, k=top_k)
        sparse_hits = self.bm25_retriever.get_relevant_documents(query)

        merged = []
        seen_texts = set()
        for candidate in dense_hits + sparse_hits:
            text = candidate.page_content
            if text in seen_texts:
                continue
            seen_texts.add(text)
            merged.append(candidate)
        return merged

# 6. Reranking System --------------------------------
class Reranker:
    """Cross-encoder reranker (default ``BAAI/bge-reranker-base``)."""

    def __init__(self, model_name="BAAI/bge-reranker-base"):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            logger.info(f"Loaded reranking model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading reranking model: {e}")
            raise

    def rerank(self, query: str, documents: List[Document], top_n: int = 3) -> List[Document]:
        """Return up to ``top_n`` documents ordered by descending relevance to ``query``.

        Args:
            query: The user question.
            documents: Candidate documents; only ``page_content`` is scored.
            top_n: Maximum number of documents to return.

        Returns:
            The highest-scoring documents first. ``[]`` for no input. If scoring
            fails, the error is logged and the first ``top_n`` candidates are returned.
        """
        import torch  # backend of the transformers model; local import keeps cell imports unchanged

        if not documents:
            return []
        try:
            pairs = [(query, doc.page_content) for doc in documents]
            features = self.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
            with torch.no_grad():  # inference only: no autograd graph needed
                logits = self.model(**features).logits
            # Logits have shape (N, num_labels). Sorting that 2-D tensor directly
            # sorts each length-1 row, so every index is 0 and the first document
            # is returned top_n times. Reduce to one score per document first.
            # bge rerankers emit a single relevance logit; for multi-label heads
            # the last column is used (assumed positive class — confirm).
            scores = logits.reshape(len(documents), -1)[:, -1]
            order = torch.argsort(scores, descending=True)
            return [documents[int(i)] for i in order[:top_n]]
        except Exception as e:
            logger.error(f"Error during reranking: {e}")
            return documents[:top_n]  # Fallback to top N documents

# 7. LangGraph Workflow --------------------------------
class GraphState(TypedDict):
    """Shared state exchanged by the retrieve and generate nodes."""

    question: str  # user question, supplied to app.invoke()
    context: List[Document]  # reranked chunks ([] if retrieval failed)
    answer: str  # model response, or an apology message on failure

# Cache for the expensive retrieval components. The BM25 index and the reranker
# model weights do not change between questions, so they are built once.
# Re-running this cell resets the cache.
_retrieval_cache = {}


def _retrieval_components():
    """Build the (HybridRetriever, Reranker) pair on first use and reuse it afterwards."""
    if "components" not in _retrieval_cache:
        bm25_retriever = BM25Retriever.from_documents(docs)
        _retrieval_cache["components"] = (
            HybridRetriever(vectorstore_disk, bm25_retriever),
            Reranker(),
        )
    return _retrieval_cache["components"]


def retrieve_nodes(state: GraphState) -> GraphState:
    """LangGraph node: hybrid-retrieve candidates for the question, then rerank them.

    No HierarchicalIndex is built here. Its result was never used, but building it
    re-embedded the entire corpus and re-ran KMeans on every question, and a KMeans
    failure also wiped out the retrieved context via the except below.

    On any error the exception is logged and ``state["context"]`` falls back to
    ``[]`` so the generate node can still respond.
    """
    try:
        hybrid_retriever, reranker = _retrieval_components()
        initial_docs = hybrid_retriever.retrieve(state["question"])
        state["context"] = reranker.rerank(state["question"], initial_docs)
        logger.info(f"Retrieved and reranked {len(state['context'])} documents.")
    except Exception as e:
        logger.error(f"Error in retrieve_nodes: {e}")
        state["context"] = []  # Fallback to empty context
    return state

def generate_answer(state: GraphState) -> GraphState:
    """LangGraph node: ask Gemini to answer ``state["question"]`` using ``state["context"]``.

    The context chunks are joined with blank lines into the prompt. If building or
    invoking the chain raises, the error is logged and ``state["answer"]`` is set
    to a fixed apology message instead.
    """
    try:
        context_text = "\n\n".join(chunk.page_content for chunk in state["context"])
        template = """
        [Advanced RAG System]
        Context: {context}
        ---
        Question: {question}
        Answer in markdown with sources. If unsure, say "I don't know".
        """
        answer_prompt = PromptTemplate.from_template(template)
        gemini = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
        pipeline = answer_prompt | gemini | StrOutputParser()
        inputs = {"question": state["question"], "context": context_text}
        state["answer"] = pipeline.invoke(inputs)
        logger.info("Generated answer successfully.")
    except Exception as e:
        logger.error(f"Error in generate_answer: {e}")
        state["answer"] = "I'm sorry, I couldn't generate an answer. Please try again."
    return state

# 8. Full Pipeline Execution --------------------------------
# Linear graph: retrieve -> generate -> END.
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve_nodes)
workflow.add_node("generate", generate_answer)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Compile the workflow
app = workflow.compile()

# Example Query: only "question" is supplied; the nodes fill "context" and "answer".
try:
    result = app.invoke({"question": "Run a swot analysis and tell me if this is good for me to leave my job to sign up as a professional with this new company"})
    print(result["answer"])
except Exception as e:
    logger.error(f"Error executing workflow: {e}")

In [None]:
# Reuses the compiled `app` from the previous cell with a different question.
try:
    result = app.invoke({"question": "Using a swot analysis tell me about all the reasons not to consider this offer"})
    print(result["answer"])
except Exception as e:
    logger.error(f"Error executing workflow: {e}")

### Creating a streaming output and a validation pipeline using pydantic

In [None]:
import os
import logging
from typing import List, Dict, TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor
from langchain import PromptTemplate
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document, StrOutputParser
from langchain.vectorstores import Chroma
from langchain.retrievers import BM25Retriever
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from rank_bm25 import BM25Okapi
import numpy as np
from sklearn.cluster import KMeans
from pydantic import BaseModel, ValidationError
import asyncio

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read the API key from the environment instead of hard-coding it. A literal key
# in a notebook is exposed to anyone who can read the file; any key that was
# previously pasted here should be treated as compromised and rotated.
if not os.environ.get("GOOGLE_API_KEY"):
    logger.warning(
        "GOOGLE_API_KEY is not set; export it before running this cell or Gemini calls will fail."
    )

# 1. Pydantic Models for Data Validation --------------------------------
class QueryInput(BaseModel):
    """Validated user request passed to ``main``."""

    question: str  # free-text question
    top_k: Optional[int] = 5  # requested result count; None is accepted

class QueryOutput(BaseModel):
    """Validated response: the answer text plus the sources it is based on."""

    answer: str  # generated answer text
    sources: List[str]  # identifiers of the context documents

# 2. Improved Document Processing for Multiple PDFs --------------------------------
def load_and_chunk_pdf(file_path: str) -> List[Document]:
    """Parse one PDF and split it into chunks of up to 1000 characters (200 overlap).

    Any exception raised while loading or splitting is logged and an empty list is
    returned, so a single bad file does not abort a batch.
    """
    splitter_options = {
        "chunk_size": 1000,
        "chunk_overlap": 200,
        "separators": ["\n\n", "\n", " ", ""],
    }
    try:
        pdf_loader = PyPDFLoader(file_path)
        chunker = RecursiveCharacterTextSplitter(**splitter_options)
        chunks = pdf_loader.load_and_split(chunker)
    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
        return []
    return chunks

def load_and_chunk_folder(folder_path: str) -> List[Document]:
    """Load and chunk all PDFs directly inside a folder (not recursive).

    Regular files whose name ends in ``.pdf`` are matched case-insensitively, so
    ``.PDF`` exports are included. They are processed in sorted name order, which
    makes the chunk order reproducible; ``os.listdir`` order is arbitrary. Files
    are parsed concurrently with ``load_and_chunk_pdf``.

    Args:
        folder_path: Directory to scan.

    Returns:
        All chunks, grouped by file in sorted order; ``[]`` (with a warning) when
        the folder contains no PDFs.

    Raises:
        ValueError: If ``folder_path`` is not an existing directory.
    """
    if not os.path.isdir(folder_path):
        raise ValueError(f"{folder_path} is not a valid directory.")

    pdf_files = []
    for name in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, name)
        # Skip directories that happen to be named "*.pdf".
        if name.lower().endswith(".pdf") and os.path.isfile(path):
            pdf_files.append(path)
    if not pdf_files:
        logger.warning(f"No PDF files found in {folder_path}.")
        return []

    all_docs = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(load_and_chunk_pdf, pdf_file) for pdf_file in pdf_files]
        # Collect results in submission order so output follows the sorted file list.
        for future in futures:
            try:
                all_docs.extend(future.result())
            except Exception as e:
                logger.error(f"Error processing PDF: {e}")

    logger.info(f"Loaded and chunked {len(all_docs)} documents from {len(pdf_files)} PDFs.")
    return all_docs

# Process all PDFs in a folder
try:
    folder_path = "/content/Rag_data"  # Replace with your folder path
    docs = load_and_chunk_folder(folder_path)
except Exception as e:
    logger.error(f"Failed to process folder: {e}")
    raise

# 3. Initialize Chroma Vector Store ------------------------------
# Every cell persists into "./chroma_db". Earlier cells store the Gemini blog post
# in the default collection, and Chroma.from_documents adds new entries (fresh ids)
# on each run. Use a dedicated collection and clear it before re-indexing, so the
# store holds exactly the chunks in `docs` (the same corpus BM25 uses).
COLLECTION_NAME = "masterhive_pdfs"
try:
    gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory="./chroma_db",
        embedding_function=gemini_embeddings,
    ).delete_collection()
    vectorstore = Chroma.from_documents(
        documents=docs,
        embedding=gemini_embeddings,
        persist_directory="./chroma_db",
        collection_name=COLLECTION_NAME,
    )
    logger.info("Vector store initialized and persisted.")
except Exception as e:
    logger.error(f"Error initializing vector store: {e}")
    raise

# Load the persisted vector store from disk (same dedicated collection)
try:
    vectorstore_disk = Chroma(
        persist_directory="./chroma_db",
        embedding_function=gemini_embeddings,
        collection_name=COLLECTION_NAME,
    )
    logger.info("Loaded vector store from disk.")
except Exception as e:
    logger.error(f"Error loading vector store from disk: {e}")
    raise

# 4. Advanced Indexing: Hierarchical Indexing --------------------------------
class HierarchicalIndex:
    """Two-level semantic index: KMeans clusters of chunk embeddings -> chunks.

    ``retrieve`` routes a query to the clusters whose centroids are most similar
    to it, then ranks only the chunks in those clusters by cosine similarity.
    """

    def __init__(self, documents: List[Document], embedding_model):
        """
        Args:
            documents: Chunks to index.
            embedding_model: Provides ``embed_documents(texts)`` and either
                ``embed_query(text)`` or ``embed(text)``.
        """
        self.documents = documents
        self.embedding_model = embedding_model
        self.index = self._build_index()

    def _build_index(self):
        """Embed every document, L2-normalise the vectors and cluster them.

        Returns:
            Mapping ``cluster id -> list of indices into self.documents``.
        """
        self._embeddings = np.zeros((0, 0))
        self._centroids = np.zeros((0, 0))
        if not self.documents:
            self.index = {}
            return self.index
        vectors = self.embedding_model.embed_documents([doc.page_content for doc in self.documents])
        self._embeddings = self._normalize(np.asarray(vectors, dtype=float))
        self.index = self._cluster_documents(self._embeddings)
        return self.index

    def _cluster_documents(self, embeddings, n_clusters=10):
        """Run KMeans and record which documents fall into each cluster.

        ``n_clusters`` is capped at the number of documents because KMeans raises
        when asked for more clusters than samples. ``random_state`` makes the
        clustering reproducible.
        """
        n_clusters = max(1, min(n_clusters, len(embeddings)))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
        labels = kmeans.fit_predict(embeddings)
        self._centroids = self._normalize(np.asarray(kmeans.cluster_centers_, dtype=float))
        clusters = {cluster_id: [] for cluster_id in range(n_clusters)}
        for doc_idx, label in enumerate(labels):
            clusters[int(label)].append(doc_idx)
        return clusters

    @staticmethod
    def _normalize(matrix):
        """Scale vectors along the last axis to unit L2 norm (zero vectors unchanged)."""
        norms = np.linalg.norm(matrix, axis=-1, keepdims=True)
        return matrix / np.where(norms == 0, 1.0, norms)

    def retrieve(self, query: str, top_k: int = 5, n_probe: int = 2) -> List[Document]:
        """Return up to ``top_k`` documents most similar to ``query``.

        Args:
            query: Free-text query.
            top_k: Maximum number of documents to return.
            n_probe: Number of nearest clusters to search. If those clusters hold
                fewer than ``top_k`` documents, all documents are ranked instead.
        """
        if not self.documents or top_k <= 0:
            return []
        embed = getattr(self.embedding_model, "embed_query", None) or self.embedding_model.embed
        query_vec = self._normalize(np.asarray(embed(query), dtype=float))
        # Route the query to its closest clusters by centroid cosine similarity.
        nearest = np.argsort(-(self._centroids @ query_vec), kind="stable")[:max(1, n_probe)]
        candidates = [i for c in nearest for i in self.index.get(int(c), [])]
        if len(candidates) < top_k:
            candidates = list(range(len(self.documents)))
        scores = self._embeddings[candidates] @ query_vec
        best = np.argsort(-scores, kind="stable")[:top_k]
        return [self.documents[candidates[i]] for i in best]

# 5. Contextual Embeddings --------------------------------
class ContextualEmbeddings:
    """Thin wrapper around Gemini embeddings with query and document helpers."""

    def __init__(self, model_name="models/embedding-001"):
        self.model = GoogleGenerativeAIEmbeddings(model=model_name)

    def embed(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.model.embed_query(text)

    def embed_query(self, text: str) -> List[float]:
        """Alias of ``embed`` matching the LangChain ``Embeddings`` method name.

        Code written against LangChain embedding models (e.g.
        ``HierarchicalIndex.retrieve``) calls ``embed_query``; without this alias
        it raises ``AttributeError`` on this wrapper.
        """
        return self.embed(text)

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a batch of document texts."""
        return self.model.embed_documents(documents)

# 6. Hybrid Retrieval System --------------------------------
class HybridRetriever:
    """Merge dense (vector) and sparse (BM25) results, dropping duplicate texts."""

    def __init__(self, vector_store, bm25_retriever):
        self.vector_store = vector_store
        self.bm25_retriever = bm25_retriever

    def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
        """Return vector hits first, then BM25 hits, keeping the first copy of each text.

        ``top_k`` bounds only the vector search; the BM25 side returns as many
        documents as its own ``k`` setting allows.
        """
        dense_hits = self.vector_store.similarity_search(query, k=top_k)
        sparse_hits = self.bm25_retriever.get_relevant_documents(query)

        merged = []
        seen_texts = set()
        for candidate in dense_hits + sparse_hits:
            text = candidate.page_content
            if text in seen_texts:
                continue
            seen_texts.add(text)
            merged.append(candidate)
        return merged

# 7. Reranking System --------------------------------
class Reranker:
    """Cross-encoder reranker (default ``BAAI/bge-reranker-base``)."""

    def __init__(self, model_name="BAAI/bge-reranker-base"):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            logger.info(f"Loaded reranking model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading reranking model: {e}")
            raise

    def rerank(self, query: str, documents: List[Document], top_n: int = 3) -> List[Document]:
        """Return up to ``top_n`` documents ordered by descending relevance to ``query``.

        Args:
            query: The user question.
            documents: Candidate documents; only ``page_content`` is scored.
            top_n: Maximum number of documents to return.

        Returns:
            The highest-scoring documents first. ``[]`` for no input. If scoring
            fails, the error is logged and the first ``top_n`` candidates are returned.
        """
        import torch  # backend of the transformers model; local import keeps cell imports unchanged

        if not documents:
            return []
        try:
            pairs = [(query, doc.page_content) for doc in documents]
            features = self.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
            with torch.no_grad():  # inference only: no autograd graph needed
                logits = self.model(**features).logits
            # Logits have shape (N, num_labels). Sorting that 2-D tensor directly
            # sorts each length-1 row, so every index is 0 and the first document
            # is returned top_n times. Reduce to one score per document first.
            # bge rerankers emit a single relevance logit; for multi-label heads
            # the last column is used (assumed positive class — confirm).
            scores = logits.reshape(len(documents), -1)[:, -1]
            order = torch.argsort(scores, descending=True)
            return [documents[int(i)] for i in order[:top_n]]
        except Exception as e:
            logger.error(f"Error during reranking: {e}")
            return documents[:top_n]  # Fallback to top N documents

# 8. LangGraph Workflow --------------------------------
class GraphState(TypedDict):
    """Shared state for the retrieval graph built in ``main``."""

    question: str  # user question, supplied when the graph is invoked
    context: List[Document]  # reranked chunks ([] if retrieval failed)
    answer: str  # final answer text (the streaming generator in this cell yields chunks instead)

# Cache for the expensive retrieval components. The BM25 index and the reranker
# model weights do not change between questions, so they are built once.
# Re-running this cell resets the cache.
_retrieval_cache = {}


def _retrieval_components():
    """Build the (HybridRetriever, Reranker) pair on first use and reuse it afterwards."""
    if "components" not in _retrieval_cache:
        bm25_retriever = BM25Retriever.from_documents(docs)
        _retrieval_cache["components"] = (
            HybridRetriever(vectorstore_disk, bm25_retriever),
            Reranker(),
        )
    return _retrieval_cache["components"]


def retrieve_nodes(state: GraphState) -> GraphState:
    """LangGraph node: hybrid-retrieve candidates for the question, then rerank them.

    No HierarchicalIndex is built here. Its result was never used, but building it
    re-embedded the entire corpus and re-ran KMeans on every question, and a KMeans
    failure also wiped out the retrieved context via the except below.

    On any error the exception is logged and ``state["context"]`` falls back to
    ``[]`` so answer generation can still respond.
    """
    try:
        hybrid_retriever, reranker = _retrieval_components()
        initial_docs = hybrid_retriever.retrieve(state["question"])
        state["context"] = reranker.rerank(state["question"], initial_docs)
        logger.info(f"Retrieved and reranked {len(state['context'])} documents.")
    except Exception as e:
        logger.error(f"Error in retrieve_nodes: {e}")
        state["context"] = []  # Fallback to empty context
    return state

async def generate_answer(state: GraphState):
    """Stream an answer to ``state["question"]`` grounded in ``state["context"]``.

    Async generator that yields text chunks as the model produces them. If
    anything fails, the error is logged and a single apology chunk is yielded
    (after any chunks that were already streamed).
    """
    try:
        formatted_docs = "\n\n".join(d.page_content for d in state["context"])
        prompt = PromptTemplate.from_template("""
        [Advanced RAG System]
        Context: {context}
        ---
        Question: {question}
        Answer in markdown with sources. If unsure, say "I don't know".
        """)
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
        chain = prompt | llm | StrOutputParser()
        # astream awaits the model between chunks, so the event loop stays free;
        # the synchronous stream() iterator would block it for the whole answer.
        async for chunk in chain.astream({
            "question": state["question"],
            "context": formatted_docs
        }):
            yield chunk
        logger.info("Generated answer successfully.")
    except Exception as e:
        logger.error(f"Error in generate_answer: {e}")
        yield "I'm sorry, I couldn't generate an answer. Please try again."

# 9. Main Function with Pydantic Validation and Streaming --------------------------------
async def main(question: str, top_k: int = 5):
    """Validate the query, run the retrieval graph and stream the answer to stdout.

    Logs and returns early if the input fails ``QueryInput`` validation.
    NOTE(review): ``top_k`` is validated but not forwarded to retrieval.
    """
    # Validate input
    try:
        query_input = QueryInput(question=question, top_k=top_k)
    except ValidationError as e:
        logger.error(f"Validation error: {e}")
        return

    # Initialize workflow
    workflow = StateGraph(GraphState)
    workflow.add_node("retrieve", retrieve_nodes)
    workflow.set_entry_point("retrieve")
    workflow.add_edge("retrieve", END)
    app = workflow.compile()

    # Actually execute the graph; previously it was compiled and then dropped.
    state = await app.ainvoke({"question": query_input.question})

    # Stream the answer
    print("Generating answer...")
    async for chunk in generate_answer(state):
        print(chunk, end="", flush=True)
    print("\nAnswer generation complete.")



# Example Query
import asyncio

import nest_asyncio

# Jupyter/Colab already run an event loop; nest_asyncio lets asyncio.run()
# re-enter it instead of raising "cannot be called from a running event loop".
nest_asyncio.apply()

try:
    # The compiled graph exists only inside main(), so run main() rather than
    # referencing a module-level `app` (which is undefined here).
    asyncio.run(main("Run a swot analysis and tell me if this is good for me to leave my job to sign up as a professional with the company"))
except Exception as e:
    logger.error(f"Error executing workflow: {e}")


### Integrating advanced prompting and Chat History

#### Advanced Prompting

In [None]:
import os
import logging
from typing import List, Dict, TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor
from langchain import PromptTemplate
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document, StrOutputParser
from langchain.vectorstores import Chroma
from langchain.retrievers import BM25Retriever
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from rank_bm25 import BM25Okapi
import numpy as np
from sklearn.cluster import KMeans
from pydantic import BaseModel, ValidationError
import asyncio
import nest_asyncio
import json
from collections import Counter

# Apply nest_asyncio for Jupyter/Colab compatibility
nest_asyncio.apply()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set API key only when the environment does not already provide one, so a
# real key exported before launch is not clobbered by this placeholder.
os.environ.setdefault("GOOGLE_API_KEY", "GOOGLE_API_KEY")

# Configuration Management
class Config:
    """Central settings shared by loading, retrieval and generation code."""

    # Data locations
    PDF_FOLDER_PATH = "/content/Rag_data"     # folder scanned for PDF files
    CHROMA_PERSIST_DIR = "./chroma_db"         # where the Chroma store is persisted

    # Model identifiers
    EMBEDDING_MODEL = "models/embedding-001"   # Gemini embedding model
    RERANKER_MODEL = "BAAI/bge-reranker-base"  # cross-encoder used for reranking
    GENERATION_MODEL = "gemini-2.0-flash"      # Gemini chat/generation model

    # Retrieval sizes
    TOP_K = 5  # default number of documents to retrieve
    TOP_N = 3  # default number of documents kept after reranking

# 1. Pydantic Models for Data Validation --------------------------------
class QueryInput(BaseModel):
    """Validated user input: the question plus a requested retrieval depth."""

    question: str  # free-text user question
    top_k: Optional[int] = Config.TOP_K  # defaults to Config.TOP_K; None is also accepted

class QueryOutput(BaseModel):
    """Answer text together with a list of source strings."""

    # NOTE(review): not constructed anywhere in this section; confirm it is used elsewhere.
    answer: str
    sources: List[str]  # presumably document identifiers such as file paths — confirm

# 2. Improved Document Processing for Multiple PDFs --------------------------------
def load_and_chunk_pdf(file_path: str) -> List[Document]:
    """Read one PDF and cut it into ~1000-character chunks with 200 characters of overlap.

    Any error while loading or splitting is logged and an empty list is returned.
    """
    try:
        pdf_loader = PyPDFLoader(file_path)
        chunker = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", " ", ""],
            chunk_overlap=200,
            chunk_size=1000,
        )
        chunks = pdf_loader.load_and_split(chunker)
    except Exception as err:
        logger.error(f"Error processing {file_path}: {err}")
        return []
    return chunks

def load_and_chunk_folder(folder_path: str) -> List[Document]:
    """Load and chunk every PDF file directly inside *folder_path*.

    Files are matched on a case-insensitive ``.pdf`` extension and processed in
    sorted filename order, so the resulting chunk order is reproducible. PDFs
    are loaded concurrently in a thread pool; a file that fails is logged and
    skipped. Returns ``[]`` (with a warning) when no PDFs are found.

    Raises:
        ValueError: if *folder_path* is not a directory.
    """
    if not os.path.isdir(folder_path):
        raise ValueError(f"{folder_path} is not a valid directory.")

    pdf_files = [
        os.path.join(folder_path, name)
        # sorted(): os.listdir order is arbitrary; lower(): don't skip "Report.PDF".
        for name in sorted(os.listdir(folder_path))
        if name.lower().endswith(".pdf") and os.path.isfile(os.path.join(folder_path, name))
    ]
    if not pdf_files:
        logger.warning(f"No PDF files found in {folder_path}.")
        return []

    all_docs = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(load_and_chunk_pdf, pdf_file) for pdf_file in pdf_files]
        # Collect in submission order so chunks stay grouped by file.
        for pdf_file, future in zip(pdf_files, futures):
            try:
                all_docs.extend(future.result())
            except Exception as e:
                logger.error(f"Error processing PDF {pdf_file}: {e}")

    logger.info(f"Loaded and chunked {len(all_docs)} documents from {len(pdf_files)} PDFs.")
    return all_docs

# Chunk every PDF under the configured folder; a failure here is fatal for the cell.
try:
    docs = load_and_chunk_folder(Config.PDF_FOLDER_PATH)
except Exception as folder_err:
    logger.error(f"Failed to process folder: {folder_err}")
    raise

# 3. Initialize Chroma Vector Store ------------------------------
def _chunks_with_stable_ids(documents):
    """Return ``(unique_documents, ids)`` with one content-derived id per chunk.

    The id is a SHA-256 of the chunk's ``source`` and ``page`` metadata plus
    its text, so the same chunk gets the same id on every run. Exact duplicates
    (same id) are kept once, because Chroma requires ids in a write to be unique.
    """
    import hashlib

    unique_docs, ids, seen = [], [], set()
    for doc in documents:
        key = f"{doc.metadata.get('source', '')}\x00{doc.metadata.get('page', '')}\x00{doc.page_content}"
        chunk_id = hashlib.sha256(key.encode("utf-8")).hexdigest()
        if chunk_id not in seen:
            seen.add(chunk_id)
            unique_docs.append(doc)
            ids.append(chunk_id)
    return unique_docs, ids


try:
    gemini_embeddings = GoogleGenerativeAIEmbeddings(model=Config.EMBEDDING_MODEL)
    # Stable ids let Chroma overwrite existing entries when this cell is re-run,
    # instead of appending another copy of every chunk to the persisted store
    # (duplicates crowd out distinct hits in similarity_search).
    unique_docs, chunk_ids = _chunks_with_stable_ids(docs)
    vectorstore = Chroma.from_documents(
        documents=unique_docs,
        embedding=gemini_embeddings,
        ids=chunk_ids,
        persist_directory=Config.CHROMA_PERSIST_DIR
    )
    logger.info("Vector store initialized and persisted.")
except Exception as e:
    logger.error(f"Error initializing vector store: {e}")
    raise

# Reopen the persisted Chroma collection with the same embedding function.
try:
    vectorstore_disk = Chroma(
        embedding_function=gemini_embeddings,
        persist_directory=Config.CHROMA_PERSIST_DIR,
    )
except Exception as load_err:
    logger.error(f"Error loading vector store from disk: {load_err}")
    raise
else:
    logger.info("Loaded vector store from disk.")

# 4. Advanced Indexing: Hierarchical Indexing --------------------------------
class HierarchicalIndex:
    """Two-level index: KMeans clusters over chunk embeddings, then per-chunk ranking.

    ``embedding_model`` must provide ``embed_documents(list[str])`` and either
    ``embed_query(str)`` or ``embed(str)``. ``self.index`` maps each cluster id
    to the positions (in ``self.documents``) of its member documents.
    """

    def __init__(self, documents: List[Document], embedding_model):
        self.documents = documents
        self.embedding_model = embedding_model
        self._embeddings = np.empty((0, 0))
        self._centroids = np.empty((0, 0))
        self.index = self._build_index()

    def _build_index(self):
        """Embed all documents and group their positions by cluster ({} if no documents)."""
        if not self.documents:
            return {}
        vectors = self.embedding_model.embed_documents([doc.page_content for doc in self.documents])
        self._embeddings = np.asarray(vectors, dtype=float)
        return self._cluster_documents(self._embeddings)

    def _cluster_documents(self, embeddings, n_clusters=10):
        """Cluster *embeddings* and return ``{cluster_id: [row_index, ...]}``.

        ``n_clusters`` is capped at the number of rows, because KMeans cannot
        fit more clusters than samples. Centroids are kept for query routing.
        """
        embeddings = np.asarray(embeddings, dtype=float)
        n_clusters = max(1, min(n_clusters, len(embeddings)))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
        labels = kmeans.fit_predict(embeddings)
        self._centroids = kmeans.cluster_centers_
        index = {i: [] for i in range(n_clusters)}
        for row, label in enumerate(labels):
            index[int(label)].append(row)
        return index

    def retrieve(self, query: str, top_k: int = Config.TOP_K) -> List[Document]:
        """Return up to *top_k* documents most similar to *query*.

        Clusters are visited in order of centroid similarity until at least
        *top_k* candidates are gathered; candidates are then ranked by cosine
        similarity to the query embedding.
        """
        if top_k <= 0 or not self.documents or not self.index:
            return []
        # LangChain embeddings expose embed_query; ContextualEmbeddings (in this
        # notebook) exposes embed for the same purpose.
        embed_query = getattr(self.embedding_model, "embed_query", None) or self.embedding_model.embed
        query_vec = np.asarray(embed_query(query), dtype=float)

        candidates = []
        for cluster_id in np.argsort(-self._cosine(self._centroids, query_vec), kind="stable"):
            candidates.extend(self.index[int(cluster_id)])
            if len(candidates) >= top_k:
                break
        scores = self._cosine(self._embeddings[candidates], query_vec)
        best = np.argsort(-scores, kind="stable")[:top_k]
        return [self.documents[candidates[i]] for i in best]

    @staticmethod
    def _cosine(matrix, vector):
        """Cosine similarity of each row of *matrix* with *vector* (0 where a norm is 0)."""
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(vector)
        dots = matrix @ vector
        return np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

# 5. Contextual Embeddings --------------------------------
class ContextualEmbeddings:
    """Thin wrapper around GoogleGenerativeAIEmbeddings.

    Exposes both this notebook's ``embed`` name and the LangChain-style
    ``embed_query`` / ``embed_documents`` names.
    """

    def __init__(self, model_name=Config.EMBEDDING_MODEL):
        self.model = GoogleGenerativeAIEmbeddings(model=model_name)

    def embed(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.model.embed_query(text)

    def embed_query(self, text: str) -> List[float]:
        """LangChain-compatible alias for :meth:`embed`.

        HierarchicalIndex.retrieve calls ``embed_query`` on its embedding model,
        which raised AttributeError before this alias existed.
        """
        return self.embed(text)

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a batch of document texts."""
        return self.model.embed_documents(documents)

# 6. Hybrid Retrieval System --------------------------------
class HybridRetriever:
    """Merge dense vector-store hits with BM25 keyword hits, dropping duplicate texts."""

    def __init__(self, vector_store, bm25_retriever):
        self.vector_store = vector_store
        self.bm25_retriever = bm25_retriever

    def retrieve(self, query: str, top_k: int = Config.TOP_K) -> List[Document]:
        """Return vector results (up to *top_k*) followed by BM25 results.

        A document whose ``page_content`` was already seen is skipped, so the
        first occurrence wins. The BM25 side returns as many documents as the
        retriever is configured for (its own ``k``).
        """
        vector_results = self.vector_store.similarity_search(query, k=top_k)
        # invoke() is the Runnable entry point; get_relevant_documents() is deprecated.
        bm25_results = self.bm25_retriever.invoke(query)

        merged = []
        seen_texts = set()
        for doc in [*vector_results, *bm25_results]:
            if doc.page_content in seen_texts:
                continue
            seen_texts.add(doc.page_content)
            merged.append(doc)
        return merged

# 7. Reranking System --------------------------------
class Reranker:
    """Cross-encoder reranker: scores (query, document) pairs and keeps the best."""

    def __init__(self, model_name=Config.RERANKER_MODEL):
        """Load the tokenizer and sequence-classification model for *model_name*.

        Raises whatever ``from_pretrained`` raises, after logging it.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            logger.info(f"Loaded reranking model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading reranking model: {e}")
            raise

    def rerank(self, query: str, documents: List[Document], top_n: int = Config.TOP_N) -> List[Document]:
        """Return the *top_n* documents ordered by descending relevance to *query*.

        Returns ``[]`` for empty input or non-positive *top_n*. If scoring fails,
        logs the error and falls back to the first *top_n* documents unchanged.
        """
        if not documents or top_n <= 0:
            return []
        try:
            import torch

            pairs = [(query, doc.page_content) for doc in documents]
            features = self.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
            with torch.no_grad():  # inference only; no autograd graph needed
                logits = self.model(**features).logits
            # One relevance logit per pair comes back as shape (n, 1). Flatten it
            # so argsort ranks documents against each other; sorting the (n, 1)
            # tensor directly sorts each one-element row and always yields 0.
            scores = logits.squeeze(-1).float()
            if scores.dim() != 1:
                raise ValueError(f"expected one relevance logit per pair, got shape {tuple(logits.shape)}")
            order = torch.argsort(scores, descending=True)[:top_n].tolist()
            return [documents[i] for i in order]
        except Exception as e:
            logger.error(f"Error during reranking: {e}")
            return documents[:top_n]  # Fallback to top N documents

# 8. Advanced Prompting Strategies --------------------------------
class AdvancedPrompts:
    """Static builders for the plain-text LLM prompts used in retrieval and generation.

    Each method returns a fully formatted string (f-string interpolation), so any
    braces in the interpolated values appear verbatim in the result.
    """

    @staticmethod
    def retrieval_prompt(query: str) -> str:
        """Build a prompt asking the LLM to rewrite *query* for better retrieval."""
        return f"""
        You are an expert retrieval system. Rewrite the following query to improve retrieval:
        Original Query: {query}
        Rewritten Query:
        """

    @staticmethod
    def metadata_filter_prompt(query: str, metadata: Dict[str, str]) -> str:
        """Build a prompt asking the LLM to fold *metadata* into *query*.

        *metadata* is interpolated with its ``str()`` form (a dict repr).
        """
        return f"""
        You are an expert retrieval system. Add metadata to the query for better filtering:
        Original Query: {query}
        Metadata: {metadata}
        Enhanced Query:
        """

    @staticmethod
    def generation_prompt(context: str, question: str) -> str:
        """Build the answer-generation prompt from retrieved *context* and the user *question*."""
        return f"""
        You are an expert in talent management and career development. Use the following context to answer the question.
        Context: {context}
        ---
        Question: {question}
        Answer in markdown with sources. If unsure, say "I don't know".
        """

# 9. Metadata Extraction and Filtering --------------------------------
class MetadataExtractor:
    """Use the generation LLM to pull structured metadata out of a user query."""

    @staticmethod
    def extract_from_query(query: str) -> Dict[str, str]:
        """Ask the LLM for metadata about *query* as a JSON object.

        Returns the parsed object, or ``{}`` (after logging) when the LLM call
        fails or its reply contains no valid JSON object. Metadata is optional
        enrichment, so a failure here must not abort retrieval.
        """
        prompt = f"""
        Extract metadata from the following query. Return the result as a JSON object with keys like "year", "topic", "location", etc.
        Query: {query}
        Metadata:
        """
        try:
            llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
            response = llm.invoke(prompt).content
            return MetadataExtractor._parse_json_object(response)
        except Exception as e:
            logger.error(f"Error extracting metadata: {e}")
            return {}

    @staticmethod
    def _parse_json_object(text: str) -> Dict[str, str]:
        """Parse the span from the first ``{`` to the last ``}`` of *text* as JSON.

        LLMs often wrap JSON in Markdown code fences or surround it with prose,
        which plain ``json.loads`` rejects; slicing out the object tolerates both.

        Raises:
            ValueError: if no ``{...}`` span exists or it is not valid JSON.
        """
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise ValueError("no JSON object found in LLM response")
        return json.loads(text[start:end + 1])

class MetadataInferrer:
    """Derive default metadata (latest year, most common topic) from chunk metadata."""

    @staticmethod
    def infer_from_documents(documents: List[Document]) -> Dict[str, str]:
        """Summarise the ``year`` / ``topic`` metadata found on *documents*.

        Sets ``"year"`` to the maximum year and ``"topic"`` to the most frequent
        topic, each only when at least one document has a non-None value for it.
        """
        metadata = {}

        # Skip missing *and* explicit None values: max() cannot compare None with
        # real years, and None should never be reported as the common topic.
        years = [doc.metadata["year"] for doc in documents if doc.metadata.get("year") is not None]
        if years:
            metadata["year"] = max(years)

        topics = [doc.metadata["topic"] for doc in documents if doc.metadata.get("topic") is not None]
        if topics:
            metadata["topic"] = Counter(topics).most_common(1)[0][0]

        return metadata

def combine_metadata(query_metadata: Dict[str, str], document_metadata: Dict[str, str]) -> Dict[str, str]:
    """Merge two metadata dicts into a new dict; query values win on key clashes.

    Neither input is modified.
    """
    return {**document_metadata, **query_metadata}

# 10. LangGraph Workflow --------------------------------
# State dict passed between LangGraph nodes (functional TypedDict form).
GraphState = TypedDict(
    "GraphState",
    {
        "question": str,            # user question that drives retrieval
        "context": List[Document],  # chunks written by retrieve_nodes
        "answer": str,              # final answer text
    },
)

def retrieve_nodes(state: GraphState) -> GraphState:
    """LangGraph node: enrich/rewrite the question, hybrid-retrieve and rerank.

    Steps:
    1. Extract metadata from the question with the LLM.
    2. Infer metadata from the module-level ``docs`` and combine the two.
    3. Ask the LLM to rewrite the question, then fold the metadata into it.
    4. Retrieve with HybridRetriever (``vectorstore_disk`` + BM25 over ``docs``).
    5. Store the reranker's top results in ``state["context"]``.

    Any failure is logged and leaves ``state["context"] == []``.
    """
    try:
        # No HierarchicalIndex is built here: its result was never used, yet
        # building it embedded every chunk through the API on each query and
        # ran a 10-cluster KMeans that raises on corpora with fewer than 10
        # chunks, which emptied the context for small document sets.
        bm25_retriever = BM25Retriever.from_documents(docs)
        hybrid_retriever = HybridRetriever(vectorstore_disk, bm25_retriever)
        reranker = Reranker()

        # Extract metadata from the query
        query_metadata = MetadataExtractor.extract_from_query(state["question"])

        # Infer metadata from the document set
        document_metadata = MetadataInferrer.infer_from_documents(docs)

        # Combine metadata (query values override document-derived ones)
        metadata = combine_metadata(query_metadata, document_metadata)

        # Rewrite the query, then add metadata to the rewritten form
        llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
        rewritten_query = llm.invoke(AdvancedPrompts.retrieval_prompt(state["question"])).content
        enhanced_query = llm.invoke(AdvancedPrompts.metadata_filter_prompt(rewritten_query, metadata)).content

        # Retrieve documents
        initial_docs = hybrid_retriever.retrieve(enhanced_query)
        state["context"] = reranker.rerank(enhanced_query, initial_docs)
        logger.info(f"Retrieved and reranked {len(state['context'])} documents.")
    except Exception as e:
        logger.error(f"Error in retrieve_nodes: {e}")
        state["context"] = []  # Fallback to empty context
    return state

async def generate_answer(state: GraphState):
    """Stream an answer to ``state["question"]`` grounded in ``state["context"]``.

    Async generator that yields text chunks as the model produces them. If
    anything fails, the error is logged and a single apology chunk is yielded
    (after any chunks that were already streamed).
    """
    try:
        formatted_docs = "\n\n".join(d.page_content for d in state["context"])
        prompt = AdvancedPrompts.generation_prompt(formatted_docs, state["question"])
        llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
        # The prompt is already fully formatted, so send it straight to the
        # model. Wrapping it in PromptTemplate.from_template re-parsed any "{"
        # or "}" from the documents or question (JSON, code, ...) as template
        # syntax and made generation fail.
        chain = llm | StrOutputParser()
        # astream keeps the event loop free while waiting on the model.
        async for chunk in chain.astream(prompt):
            yield chunk
        logger.info("Generated answer successfully.")
    except Exception as e:
        logger.error(f"Error in generate_answer: {e}")
        yield "I'm sorry, I couldn't generate an answer. Please try again."

# 11. Main Function with Advanced Prompting --------------------------------
async def main(question: str, top_k: int = Config.TOP_K):
    """Validate the query, run the retrieval graph and stream the answer to stdout.

    Logs and returns early if the input fails ``QueryInput`` validation.
    NOTE(review): ``top_k`` is validated but not forwarded to retrieval.
    """
    # Validate input
    try:
        query_input = QueryInput(question=question, top_k=top_k)
    except ValidationError as e:
        logger.error(f"Validation error: {e}")
        return

    # Initialize workflow
    workflow = StateGraph(GraphState)
    workflow.add_node("retrieve", retrieve_nodes)
    workflow.set_entry_point("retrieve")
    workflow.add_edge("retrieve", END)
    app = workflow.compile()

    # ainvoke keeps this coroutine non-blocking; the synchronous invoke() would
    # stall the event loop for the whole retrieval (LLM calls + reranking).
    state = await app.ainvoke({"question": query_input.question})

    # Stream the answer
    print("Generating answer...")
    async for chunk in generate_answer(state):
        print(chunk, end="", flush=True)
    print("\nAnswer generation complete.")

# Example Usage
if __name__ == "__main__":
    # With nest_asyncio applied at the top of this cell, asyncio.run can also be
    # used from inside Jupyter's already-running loop. This replaces the
    # deprecated "get_event_loop() or create one" probe.
    asyncio.run(main("Run a SWOT analysis and tell me if this is good for me to leave my job to sign up as a professional with the company"))

#### Chat History with MongoDB

In [None]:
import os
import logging
from typing import List, Dict, TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor
from langchain import PromptTemplate
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document, StrOutputParser
from langchain.vectorstores import Chroma
from langchain.retrievers import BM25Retriever
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from rank_bm25 import BM25Okapi
import numpy as np
from sklearn.cluster import KMeans
from pydantic import BaseModel, ValidationError
import asyncio
import nest_asyncio
import json
from collections import Counter
from datetime import datetime
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

# Apply nest_asyncio for Jupyter/Colab compatibility
nest_asyncio.apply()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Credentials: values already present in the environment win; the literals
# below are only placeholders and must not clobber a real key or URI.
os.environ.setdefault("GOOGLE_API_KEY", "GOOGLE_API_KEY")
uri = os.environ.get("MONGODB_URI", "mongo_db_url")


# Configuration Management
class Config:
    """Central settings shared by loading, retrieval and generation code."""

    # Data locations
    PDF_FOLDER_PATH = "/content/Rag_data"     # folder scanned for PDF files
    CHROMA_PERSIST_DIR = "./chroma_db"         # where the Chroma store is persisted

    # Model identifiers
    EMBEDDING_MODEL = "models/embedding-001"   # Gemini embedding model
    RERANKER_MODEL = "BAAI/bge-reranker-base"  # cross-encoder used for reranking
    GENERATION_MODEL = "gemini-2.0-flash"      # Gemini chat/generation model

    # Retrieval sizes
    TOP_K = 5  # default number of documents to retrieve
    TOP_N = 3  # default number of documents kept after reranking

# MongoDB handles: one document per chat message lives in chat_history_db.chats.
client = MongoClient(uri, server_api=ServerApi("1"))
db = client.get_database("chat_history_db")
chats_collection = db.get_collection("chats")

def store_chat_history(userID: str, chatID: str, role: str, content: str):
    """Insert one chat message document into ``chats_collection``.

    The timestamp is timezone-aware UTC. PyMongo treats naive datetimes as UTC,
    so a naive local ``datetime.now()`` would be stored shifted by the machine's
    UTC offset. Failures are logged and swallowed (best-effort persistence).
    """
    from datetime import timezone

    try:
        message = {
            "userID": userID,
            "chatID": chatID,
            "timestamp": datetime.now(timezone.utc),
            "role": role,
            "content": content
        }
        chats_collection.insert_one(message)
        logger.info(f"Stored message in chat history for user {userID}, chat {chatID}.")
    except Exception as e:
        logger.error(f"Error storing chat history: {e}")

def retrieve_chat_history(userID: str, chatID: str, limit: int = 5) -> List[Dict[str, str]]:
    """Return the *limit* most recent messages of a chat, oldest first.

    Messages are fetched newest-first, so ``limit`` keeps the latest ones. They
    are then reversed into chronological order, the order in which a
    conversation should read when placed into a prompt. Each item has only
    ``role`` and ``content``. Returns ``[]`` (after logging) if the query fails.
    """
    try:
        cursor = chats_collection.find(
            {"userID": userID, "chatID": chatID},
            sort=[("timestamp", -1)],  # newest first, so limit keeps the latest
            limit=limit
        )
        latest_first = [{"role": msg["role"], "content": msg["content"]} for msg in cursor]
    except Exception as e:
        logger.error(f"Error retrieving chat history: {e}")
        return []
    latest_first.reverse()
    return latest_first

# Pydantic Models for Data Validation
class QueryInput(BaseModel):
    """Validated user input: the question plus a requested retrieval depth."""

    question: str  # free-text user question
    top_k: Optional[int] = Config.TOP_K  # defaults to Config.TOP_K; None is also accepted

class QueryOutput(BaseModel):
    """Answer text together with a list of source strings."""

    # NOTE(review): not constructed in the visible part of this cell; confirm it is used elsewhere.
    answer: str
    sources: List[str]  # presumably document identifiers such as file paths — confirm

# Improved Document Processing for Multiple PDFs
def load_and_chunk_pdf(file_path: str) -> List[Document]:
    """Read one PDF and cut it into ~1000-character chunks with 200 characters of overlap.

    Any error while loading or splitting is logged and an empty list is returned.
    """
    try:
        pdf_loader = PyPDFLoader(file_path)
        chunker = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", " ", ""],
            chunk_overlap=200,
            chunk_size=1000,
        )
        chunks = pdf_loader.load_and_split(chunker)
    except Exception as err:
        logger.error(f"Error processing {file_path}: {err}")
        return []
    return chunks

def load_and_chunk_folder(folder_path: str) -> List[Document]:
    """Load and chunk every PDF file directly inside *folder_path*.

    Files are matched on a case-insensitive ``.pdf`` extension and processed in
    sorted filename order, so the resulting chunk order is reproducible. PDFs
    are loaded concurrently in a thread pool; a file that fails is logged and
    skipped. Returns ``[]`` (with a warning) when no PDFs are found.

    Raises:
        ValueError: if *folder_path* is not a directory.
    """
    if not os.path.isdir(folder_path):
        raise ValueError(f"{folder_path} is not a valid directory.")

    pdf_files = [
        os.path.join(folder_path, name)
        # sorted(): os.listdir order is arbitrary; lower(): don't skip "Report.PDF".
        for name in sorted(os.listdir(folder_path))
        if name.lower().endswith(".pdf") and os.path.isfile(os.path.join(folder_path, name))
    ]
    if not pdf_files:
        logger.warning(f"No PDF files found in {folder_path}.")
        return []

    all_docs = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(load_and_chunk_pdf, pdf_file) for pdf_file in pdf_files]
        # Collect in submission order so chunks stay grouped by file.
        for pdf_file, future in zip(pdf_files, futures):
            try:
                all_docs.extend(future.result())
            except Exception as e:
                logger.error(f"Error processing PDF {pdf_file}: {e}")

    logger.info(f"Loaded and chunked {len(all_docs)} documents from {len(pdf_files)} PDFs.")
    return all_docs

# Chunk every PDF under the configured folder; a failure here is fatal for the cell.
try:
    docs = load_and_chunk_folder(Config.PDF_FOLDER_PATH)
except Exception as folder_err:
    logger.error(f"Failed to process folder: {folder_err}")
    raise

# Initialize Chroma Vector Store
def _chunks_with_stable_ids(documents):
    """Return ``(unique_documents, ids)`` with one content-derived id per chunk.

    The id is a SHA-256 of the chunk's ``source`` and ``page`` metadata plus
    its text, so the same chunk gets the same id on every run. Exact duplicates
    (same id) are kept once, because Chroma requires ids in a write to be unique.
    """
    import hashlib

    unique_docs, ids, seen = [], [], set()
    for doc in documents:
        key = f"{doc.metadata.get('source', '')}\x00{doc.metadata.get('page', '')}\x00{doc.page_content}"
        chunk_id = hashlib.sha256(key.encode("utf-8")).hexdigest()
        if chunk_id not in seen:
            seen.add(chunk_id)
            unique_docs.append(doc)
            ids.append(chunk_id)
    return unique_docs, ids


try:
    gemini_embeddings = GoogleGenerativeAIEmbeddings(model=Config.EMBEDDING_MODEL)
    # Stable ids let Chroma overwrite existing entries when this cell (or an
    # earlier cell persisting to the same directory) is re-run, instead of
    # appending another copy of every chunk to the persisted store.
    unique_docs, chunk_ids = _chunks_with_stable_ids(docs)
    vectorstore = Chroma.from_documents(
        documents=unique_docs,
        embedding=gemini_embeddings,
        ids=chunk_ids,
        persist_directory=Config.CHROMA_PERSIST_DIR
    )
    logger.info("Vector store initialized and persisted.")
except Exception as e:
    logger.error(f"Error initializing vector store: {e}")
    raise

# Reopen the persisted Chroma collection with the same embedding function.
try:
    vectorstore_disk = Chroma(
        embedding_function=gemini_embeddings,
        persist_directory=Config.CHROMA_PERSIST_DIR,
    )
except Exception as load_err:
    logger.error(f"Error loading vector store from disk: {load_err}")
    raise
else:
    logger.info("Loaded vector store from disk.")

# Advanced Indexing: Hierarchical Indexing
class HierarchicalIndex:
    """Two-level index: KMeans clusters over chunk embeddings, then per-chunk ranking.

    ``embedding_model`` must provide ``embed_documents(list[str])`` and either
    ``embed_query(str)`` or ``embed(str)``. ``self.index`` maps each cluster id
    to the positions (in ``self.documents``) of its member documents.
    """

    def __init__(self, documents: List[Document], embedding_model):
        self.documents = documents
        self.embedding_model = embedding_model
        self._embeddings = np.empty((0, 0))
        self._centroids = np.empty((0, 0))
        self.index = self._build_index()

    def _build_index(self):
        """Embed all documents and group their positions by cluster ({} if no documents)."""
        if not self.documents:
            return {}
        vectors = self.embedding_model.embed_documents([doc.page_content for doc in self.documents])
        self._embeddings = np.asarray(vectors, dtype=float)
        return self._cluster_documents(self._embeddings)

    def _cluster_documents(self, embeddings, n_clusters=10):
        """Cluster *embeddings* and return ``{cluster_id: [row_index, ...]}``.

        ``n_clusters`` is capped at the number of rows, because KMeans cannot
        fit more clusters than samples. Centroids are kept for query routing.
        """
        embeddings = np.asarray(embeddings, dtype=float)
        n_clusters = max(1, min(n_clusters, len(embeddings)))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
        labels = kmeans.fit_predict(embeddings)
        self._centroids = kmeans.cluster_centers_
        index = {i: [] for i in range(n_clusters)}
        for row, label in enumerate(labels):
            index[int(label)].append(row)
        return index

    def retrieve(self, query: str, top_k: int = Config.TOP_K) -> List[Document]:
        """Return up to *top_k* documents most similar to *query*.

        Clusters are visited in order of centroid similarity until at least
        *top_k* candidates are gathered; candidates are then ranked by cosine
        similarity to the query embedding.
        """
        if top_k <= 0 or not self.documents or not self.index:
            return []
        # LangChain embeddings expose embed_query; ContextualEmbeddings (in this
        # notebook) exposes embed for the same purpose.
        embed_query = getattr(self.embedding_model, "embed_query", None) or self.embedding_model.embed
        query_vec = np.asarray(embed_query(query), dtype=float)

        candidates = []
        for cluster_id in np.argsort(-self._cosine(self._centroids, query_vec), kind="stable"):
            candidates.extend(self.index[int(cluster_id)])
            if len(candidates) >= top_k:
                break
        scores = self._cosine(self._embeddings[candidates], query_vec)
        best = np.argsort(-scores, kind="stable")[:top_k]
        return [self.documents[candidates[i]] for i in best]

    @staticmethod
    def _cosine(matrix, vector):
        """Cosine similarity of each row of *matrix* with *vector* (0 where a norm is 0)."""
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(vector)
        dots = matrix @ vector
        return np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

# Contextual Embeddings
class ContextualEmbeddings:
    """Thin wrapper around GoogleGenerativeAIEmbeddings.

    Exposes both this notebook's ``embed`` name and the LangChain-style
    ``embed_query`` / ``embed_documents`` names.
    """

    def __init__(self, model_name=Config.EMBEDDING_MODEL):
        self.model = GoogleGenerativeAIEmbeddings(model=model_name)

    def embed(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.model.embed_query(text)

    def embed_query(self, text: str) -> List[float]:
        """LangChain-compatible alias for :meth:`embed`.

        HierarchicalIndex.retrieve calls ``embed_query`` on its embedding model,
        which raised AttributeError before this alias existed.
        """
        return self.embed(text)

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a batch of document texts."""
        return self.model.embed_documents(documents)

# Hybrid Retrieval System
class HybridRetriever:
    """Merge dense vector-store hits with BM25 keyword hits, dropping duplicate texts."""

    def __init__(self, vector_store, bm25_retriever):
        self.vector_store = vector_store
        self.bm25_retriever = bm25_retriever

    def retrieve(self, query: str, top_k: int = Config.TOP_K) -> List[Document]:
        """Return vector results (up to *top_k*) followed by BM25 results.

        A document whose ``page_content`` was already seen is skipped, so the
        first occurrence wins. The BM25 side returns as many documents as the
        retriever is configured for (its own ``k``).
        """
        vector_results = self.vector_store.similarity_search(query, k=top_k)
        # invoke() is the Runnable entry point; get_relevant_documents() is deprecated.
        bm25_results = self.bm25_retriever.invoke(query)

        merged = []
        seen_texts = set()
        for doc in [*vector_results, *bm25_results]:
            if doc.page_content in seen_texts:
                continue
            seen_texts.add(doc.page_content)
            merged.append(doc)
        return merged

# Reranking System
class Reranker:
    """Cross-encoder reranker: scores (query, document) pairs and keeps the best."""

    def __init__(self, model_name=Config.RERANKER_MODEL):
        """Load the tokenizer and sequence-classification model for *model_name*.

        Raises whatever ``from_pretrained`` raises, after logging it.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            logger.info(f"Loaded reranking model: {model_name}")
        except Exception as e:
            logger.error(f"Error loading reranking model: {e}")
            raise

    def rerank(self, query: str, documents: List[Document], top_n: int = Config.TOP_N) -> List[Document]:
        """Return the *top_n* documents ordered by descending relevance to *query*.

        Returns ``[]`` for empty input or non-positive *top_n*. If scoring fails,
        logs the error and falls back to the first *top_n* documents unchanged.
        """
        if not documents or top_n <= 0:
            return []
        try:
            import torch

            pairs = [(query, doc.page_content) for doc in documents]
            features = self.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
            with torch.no_grad():  # inference only; no autograd graph needed
                logits = self.model(**features).logits
            # One relevance logit per pair comes back as shape (n, 1). Flatten it
            # so argsort ranks documents against each other; sorting the (n, 1)
            # tensor directly sorts each one-element row and always yields 0.
            scores = logits.squeeze(-1).float()
            if scores.dim() != 1:
                raise ValueError(f"expected one relevance logit per pair, got shape {tuple(logits.shape)}")
            order = torch.argsort(scores, descending=True)[:top_n].tolist()
            return [documents[i] for i in order]
        except Exception as e:
            logger.error(f"Error during reranking: {e}")
            return documents[:top_n]  # Fallback to top N documents

# Advanced Prompting Strategies
class AdvancedPrompts:
    """Static builders for the plain-text LLM prompts used in retrieval, generation and chat history.

    Each method returns a fully formatted string (f-string interpolation), so any
    braces in the interpolated values appear verbatim in the result. Dict and list
    arguments are interpolated with their ``str()`` (repr) form.
    """

    @staticmethod
    def retrieval_prompt(query: str) -> str:
        """Build a prompt asking the LLM to rewrite *query* for better retrieval."""
        return f"""
        You are an expert retrieval system. Rewrite the following query to improve retrieval:
        Original Query: {query}
        Rewritten Query:
        """

    @staticmethod
    def metadata_filter_prompt(query: str, metadata: Dict[str, str]) -> str:
        """Build a prompt asking the LLM to fold *metadata* (a dict) into *query*."""
        return f"""
        You are an expert retrieval system. Add metadata to the query for better filtering:
        Original Query: {query}
        Metadata: {metadata}
        Enhanced Query:
        """

    @staticmethod
    def generation_prompt(context: str, question: str) -> str:
        """Build the answer-generation prompt from retrieved *context* and the user *question*."""
        return f"""
        You are an expert in talent management and career development. Use the following context to answer the question.
        Context: {context}
        ---
        Question: {question}
        Answer in markdown with sources. If unsure, say "I don't know".
        """

    @staticmethod
    def can_answer_from_history_prompt(query: str, chat_history: List[Dict[str, str]]) -> str:
        """Build a yes/no prompt: can *query* be answered from *chat_history* alone?"""
        return f"""
        Can the following query be answered using the provided chat history? Answer with "yes" or "no".
        Query: {query}
        Chat History:
        {chat_history}
        Answer:
        """

    @staticmethod
    def enhance_query_with_history_prompt(query: str, chat_history: List[Dict[str, str]]) -> str:
        """Build a prompt asking the LLM to rewrite *query* with context from *chat_history* when needed."""
        return f"""
        Rewrite the following query to include relevant context from the chat history only if necessary:
        Query: {query}
        Chat History:
        {chat_history}
        Enhanced Query:
        """

    @staticmethod
    def answer_from_history_prompt(chat_history: List[Dict[str, str]], question: str) -> str:
        """Build a prompt asking the LLM to answer *question* using only *chat_history*."""
        return f"""
        You are an expert in talent management and career development. Use the following chat history to answer the question.
        Chat History:
        {chat_history}
        ---
        Question: {question}
        Answer in markdown with sources. If unsure, say "I don't know".
        """

# Metadata Extraction and Filtering
class MetadataExtractor:
    """Use the generation LLM to pull structured metadata out of a user query."""

    @staticmethod
    def extract_from_query(query: str) -> Dict[str, str]:
        """Ask the LLM for metadata about *query* as a JSON object.

        Returns the parsed object, or ``{}`` (after logging) when the LLM call
        fails or its reply contains no valid JSON object. Metadata is optional
        enrichment, so a failure here must not abort retrieval.
        """
        prompt = f"""
        Extract metadata from the following query. Return the result as a JSON object with keys like "year", "topic", "location", etc.
        Query: {query}
        Metadata:
        """
        try:
            llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
            response = llm.invoke(prompt).content
            return MetadataExtractor._parse_json_object(response)
        except Exception as e:
            logger.error(f"Error extracting metadata: {e}")
            return {}

    @staticmethod
    def _parse_json_object(text: str) -> Dict[str, str]:
        """Parse the span from the first ``{`` to the last ``}`` of *text* as JSON.

        LLMs often wrap JSON in Markdown code fences or surround it with prose,
        which plain ``json.loads`` rejects; slicing out the object tolerates both.

        Raises:
            ValueError: if no ``{...}`` span exists or it is not valid JSON.
        """
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise ValueError("no JSON object found in LLM response")
        return json.loads(text[start:end + 1])

class MetadataInferrer:
    """Derive default metadata (latest year, most common topic) from chunk metadata."""

    @staticmethod
    def infer_from_documents(documents: List[Document]) -> Dict[str, str]:
        """Summarise the ``year`` / ``topic`` metadata found on *documents*.

        Sets ``"year"`` to the maximum year and ``"topic"`` to the most frequent
        topic, each only when at least one document has a non-None value for it.
        """
        metadata = {}

        # Skip missing *and* explicit None values: max() cannot compare None with
        # real years, and None should never be reported as the common topic.
        years = [doc.metadata["year"] for doc in documents if doc.metadata.get("year") is not None]
        if years:
            metadata["year"] = max(years)

        topics = [doc.metadata["topic"] for doc in documents if doc.metadata.get("topic") is not None]
        if topics:
            metadata["topic"] = Counter(topics).most_common(1)[0][0]

        return metadata

def combine_metadata(query_metadata: Dict[str, str], document_metadata: Dict[str, str]) -> Dict[str, str]:
    """Merge two metadata dicts into a new dict; query values win on key clashes.

    Neither input is modified.
    """
    return {**document_metadata, **query_metadata}

# LangGraph Workflow
# State dict passed between LangGraph nodes (functional TypedDict form).
GraphState = TypedDict(
    "GraphState",
    {
        "question": str,            # user question that drives retrieval
        "context": List[Document],  # chunks written by retrieve_nodes
        "answer": str,              # final answer text, set by generate_answer
    },
)

def retrieve_nodes(state: GraphState) -> GraphState:
    """LangGraph node: enrich/rewrite the question, hybrid-retrieve and rerank.

    Steps:
    1. Extract metadata from the question with the LLM.
    2. Infer metadata from the module-level ``docs`` and combine the two.
    3. Ask the LLM to rewrite the question, then fold the metadata into it.
    4. Retrieve with HybridRetriever (``vectorstore_disk`` + BM25 over ``docs``).
    5. Store the reranker's top results in ``state["context"]``.

    Any failure is logged and leaves ``state["context"] == []``.
    """
    try:
        # No HierarchicalIndex is built here: its result was never used, yet
        # building it embedded every chunk through the API on each query and
        # ran a 10-cluster KMeans that raises on corpora with fewer than 10
        # chunks, which emptied the context for small document sets.
        bm25_retriever = BM25Retriever.from_documents(docs)
        hybrid_retriever = HybridRetriever(vectorstore_disk, bm25_retriever)
        reranker = Reranker()

        # Extract metadata from the query
        query_metadata = MetadataExtractor.extract_from_query(state["question"])

        # Infer metadata from the document set
        document_metadata = MetadataInferrer.infer_from_documents(docs)

        # Combine metadata (query values override document-derived ones)
        metadata = combine_metadata(query_metadata, document_metadata)

        # Rewrite the query, then add metadata to the rewritten form
        llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
        rewritten_query = llm.invoke(AdvancedPrompts.retrieval_prompt(state["question"])).content
        enhanced_query = llm.invoke(AdvancedPrompts.metadata_filter_prompt(rewritten_query, metadata)).content

        # Retrieve documents
        initial_docs = hybrid_retriever.retrieve(enhanced_query)
        state["context"] = reranker.rerank(enhanced_query, initial_docs)
        logger.info(f"Retrieved and reranked {len(state['context'])} documents.")
    except Exception as e:
        logger.error(f"Error in retrieve_nodes: {e}")
        state["context"] = []  # Fallback to empty context
    return state

async def generate_answer(state: GraphState):
    """Stream the answer to stdout and store the full text in ``state["answer"]``.

    Returns the same *state* object. On failure, the error is logged and
    ``state["answer"]`` is set to an apology message instead.
    """
    try:
        formatted_docs = "\n\n".join(d.page_content for d in state["context"])
        prompt = AdvancedPrompts.generation_prompt(formatted_docs, state["question"])
        llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
        # The prompt is already fully formatted, so send it straight to the
        # model. Wrapping it in PromptTemplate.from_template re-parsed any "{"
        # or "}" from the documents or question (JSON, code, ...) as template
        # syntax and made generation fail.
        chain = llm | StrOutputParser()

        # Stream the answer (without blocking the event loop) and collect it.
        parts = []
        print("Generating answer...")
        async for chunk in chain.astream(prompt):
            print(chunk, end="", flush=True)  # Stream to the user
            parts.append(chunk)
        print("\nAnswer generation complete.")

        state["answer"] = "".join(parts)
        return state
    except Exception as e:
        logger.error(f"Error in generate_answer: {e}")
        state["answer"] = "I'm sorry, I couldn't generate an answer. Please try again."
        return state

def can_answer_from_history(query: str, chat_history: List[Dict[str, str]]) -> bool:
    """Check if the query can be answered from the chat history."""
    prompt = AdvancedPrompts.can_answer_from_history_prompt(query, chat_history)
    llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
    response = llm.invoke(prompt).content.strip().lower()
    return response == "yes"

def enhance_query_with_history(query: str, chat_history: List[Dict[str, str]]) -> str:
    """Rewrite *query* into a standalone retrieval query using *chat_history*.

    Args:
        query: The user's current question.
        chat_history: Prior turns of the conversation.

    Returns:
        The LLM-rewritten query with surrounding whitespace removed. Returns
        *query* unchanged when there is no history to draw on, or when the
        model returns blank text.
    """
    if not chat_history:
        # No context to resolve references against; avoid a needless LLM call
        # that could only drift the query away from what the user asked.
        return query
    prompt = AdvancedPrompts.enhance_query_with_history_prompt(query, chat_history)
    llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
    rewritten = llm.invoke(prompt).content.strip()
    return rewritten or query

# Main Function with Advanced Prompting
async def _stream_answer_from_history(chat_history: List[Dict[str, str]], question: str) -> str:
    """Stream an answer built only from *chat_history* to stdout; return its full text."""
    prompt = AdvancedPrompts.answer_from_history_prompt(chat_history, question)
    llm = ChatGoogleGenerativeAI(model=Config.GENERATION_MODEL, temperature=0.3)
    # NOTE(review): assumes answer_from_history_prompt returns a fully formatted
    # prompt (it receives the history and question as arguments). Sending it
    # directly avoids PromptTemplate re-parsing any "{"/"}" in the history text.
    chain = llm | StrOutputParser()
    chunks = []
    async for chunk in chain.astream(prompt):
        print(chunk, end="", flush=True)
        chunks.append(chunk)
    print("\nAnswer generation complete.")
    return "".join(chunks)


def _build_retrieval_graph():
    """Compile a one-node LangGraph workflow: retrieve_nodes, then END."""
    workflow = StateGraph(GraphState)
    workflow.add_node("retrieve", retrieve_nodes)
    workflow.set_entry_point("retrieve")
    workflow.add_edge("retrieve", END)
    return workflow.compile()


async def main(userID: str, chatID: str, question: str, top_k: int = Config.TOP_K):
    """Answer *question* for one chat turn and record the turn in chat history.

    If the LLM judges the question answerable from the existing history, the
    answer is streamed from the history alone. Otherwise the question is
    rewritten using the history, documents are retrieved and reranked by the
    LangGraph workflow, and the answer is generated from them. In both cases
    the question and answer are stored in chat history.

    Args:
        userID: Identifier of the user that owns the chat.
        chatID: Identifier of the chat session.
        question: The user's question.
        top_k: Validated through QueryInput.

    Returns:
        The answer text, or None if the input fails validation.
    """
    # Validate input
    try:
        query_input = QueryInput(question=question, top_k=top_k)
    except ValidationError as e:
        logger.error(f"Validation error: {e}")
        return None
    # NOTE(review): query_input.top_k is validated but not forwarded to
    # retrieval; confirm whether GraphState/retrieve_nodes should receive it.

    # Retrieve chat history
    chat_history = retrieve_chat_history(userID, chatID)
    logger.info(f"Retrieved chat history for user {userID}, chat {chatID}.")

    if can_answer_from_history(query_input.question, chat_history):
        logger.info("Answering from chat history.")
        answer = await _stream_answer_from_history(chat_history, query_input.question)
    else:
        # Enhance query with chat history
        enhanced_query = enhance_query_with_history(query_input.question, chat_history)
        logger.info(f"Enhanced query: {enhanced_query}")

        # ainvoke keeps the event loop responsive while retrieval runs.
        state = await _build_retrieval_graph().ainvoke({"question": enhanced_query})

        # Generate the answer (stream and store in state["answer"])
        state = await generate_answer(state)
        answer = state["answer"]

    # Store the turn on both paths so later history lookups see every exchange.
    store_chat_history(userID, chatID, "user", query_input.question)
    store_chat_history(userID, chatID, "assistant", answer)
    logger.info("Stored chat history.")
    return answer

# Example Usage
if __name__ == "__main__":
    userID = "user123"
    chatID = "chat456"
    question = "Run a SWOT analysis and tell me if this is good for me to leave my job to sign up as a professional with the new company"

    # Check if an event loop is already running (e.g. inside Jupyter).
    # get_running_loop() is the non-deprecated way to ask this question.
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        # Plain script: asyncio.run creates, runs and closes a fresh loop.
        asyncio.run(main(userID, chatID, question))
    else:
        # Notebook: re-entering the already-running loop only works if
        # nest_asyncio.apply() has been called (nest-asyncio is in the pip
        # install list); TODO confirm it is applied earlier in the notebook.
        running_loop.run_until_complete(main(userID, chatID, question))

## Now that we have confirmed the pipeline works, let's turn it into a full-blown application prototype.

Further upgrades:
- Optimize for speed
- Fine-tune a custom model
- Explore other strategies for developing the system
- Develop full web and mobile apps, and optimize the UI/UX and AI functionality for the target audience