In [1]:
import os
import asyncio
import logging 
from pathlib import Path
from typing import Tuple, Any, List, Dict, Optional
from dataclasses import dataclass
from dotenv import load_dotenv

from langchain_cohere import CohereEmbeddings
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import (
    TextLoader, 
    UnstructuredMarkdownLoader,
    JSONLoader,
    UnstructuredHTMLLoader,
    PyPDFLoader
)

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnableLambda
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult

# Load variables from a local .env file (if one is present) into the process
# environment so that API keys can be read later with os.getenv.
load_dotenv()
# Configure logging: INFO level via basicConfig, plus a logger named after this
# module that the classes below use for progress and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [2]:
@dataclass
class SelfRAGResponse:
    """Result container for one Self-RAG query, including reflection details.

    NOTE(review): the code that populates this class is not visible in this
    cell, so the per-field notes below are inferred from names and types and
    should be confirmed against the producer.
    """
    answer: str  # presumably the final generated answer text
    retrieved_docs: List[Document]  # presumably the documents retrieved for the query
    reflection_score: float  # presumably a self-assessment score; range not shown here - TODO confirm
    needs_retrieval: bool  # presumably the outcome of the retrieve / no-retrieve decision
    citations: List[str]  # presumably source identifiers cited by the answer; format TODO confirm
    retrieval_decision_reasoning: str  # presumably the model's stated reasoning for that decision


class RateLimitCallback(AsyncCallbackHandler):
    """Callback handler that bounds concurrent LLM calls with a semaphore.

    One permit is acquired when an LLM run starts and released when the run
    finishes, whether it succeeded or failed. Releasing on failure matters:
    without it every failed call would leak a permit, and after enough
    failures all later calls would block forever in ``on_llm_start``.

    Args:
        semaphore: The ``asyncio.Semaphore`` shared by all calls that should
            be rate limited together.
    """

    def __init__(self, semaphore: asyncio.Semaphore):
        self.semaphore = semaphore

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Wait for a free permit before the LLM call proceeds."""
        await self.semaphore.acquire()

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Return the permit after a successful LLM call."""
        self.semaphore.release()

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Return the permit when the LLM call fails, so it is not leaked."""
        self.semaphore.release()

In [3]:
class DocumentLoader:
    """Load every supported file under a folder into LangChain documents.

    ``self.loaders`` maps a lowercase file suffix (including the dot) to a
    callable that accepts the file path as a string and returns a loader
    object exposing ``load()``. Files whose suffix is not in the mapping are
    ignored.
    """

    def __init__(self):
        self.loaders = {
            '.txt': TextLoader,
            '.md': UnstructuredMarkdownLoader,
            '.json': self._create_json_loader,
            '.html': UnstructuredHTMLLoader,
            '.py': TextLoader,
            '.js': TextLoader,
            '.css': TextLoader,
            '.pdf': PyPDFLoader
        }

    def _create_json_loader(self, file_path: str):
        """Create a JSON loader with a custom jq_schema.

        Args:
            file_path: Path of the JSON file to load.

        Returns:
            A ``JSONLoader`` configured with ``jq_schema='.[]'`` and
            ``text_content=False``.
        """
        return JSONLoader(
            file_path=file_path,
            jq_schema='.[]',
            text_content=False
        )

    @staticmethod
    def _tag_with_source(docs, file_path: Path) -> None:
        """Record the originating file on each document's metadata, in place."""
        source_info = {
            'file_path': str(file_path),
            'file_type': file_path.suffix,
            'file_name': file_path.name
        }
        for doc in docs:
            doc.metadata.update(source_info)

    async def load_documents(self, kb_folder: str) -> List[Document]:
        """Load all documents from the knowledge base folder.

        The folder is walked recursively. Each file with a supported suffix
        (matched case-insensitively) is loaded with its registered loader; if
        that fails, a plain ``TextLoader`` is tried as a fallback. A file that
        fails both attempts is logged and skipped, so one bad file never
        aborts the whole load.

        Args:
            kb_folder: Path of the knowledge base folder.

        Returns:
            The loaded documents, each tagged with ``file_path``,
            ``file_type`` and ``file_name`` metadata.

        Raises:
            FileNotFoundError: If ``kb_folder`` does not exist.
        """
        documents = []
        kb_path = Path(kb_folder)

        if not kb_path.exists():
            raise FileNotFoundError(f"Knowledge base folder not found: {kb_path}")

        for file_path in kb_path.glob("**/*"):
            # Normalise once so that e.g. ".PDF" and ".pdf" pick the same loader.
            suffix = file_path.suffix.lower()
            if not file_path.is_file() or suffix not in self.loaders:
                continue

            try:
                docs = self.loaders[suffix](str(file_path)).load()
            except Exception as e:
                logger.warning(f"There was an error loading the knowledge base file {file_path}: {str(e)}")
                # Fall back to TextLoader when the format-specific loader fails
                try:
                    docs = TextLoader(str(file_path)).load()
                except Exception as fallback_error:
                    logger.error(f"Failed to load {file_path} with fallback: {fallback_error}")
                    continue

            self._tag_with_source(docs, file_path)
            documents.extend(docs)

        logger.info(f"Loaded {len(documents)} documents from {kb_folder}")
        return documents

In [None]:
@dataclass
class RAGSystem:
    """Settings for the RAG pipeline; components are built in ``__post_init__``.

    The fields below are the constructor arguments generated by ``@dataclass``.
    """
    # Cohere API key (embeddings are created with CohereEmbeddings in __post_init__)
    cohere_api_key: str
    # OpenRouter API key (the chat model is pointed at the OpenRouter base URL in __post_init__)
    openrouter_api_key: str
    # Path of the knowledge base folder
    kb_folder: str
    # Size used for the rate-limiting semaphores
    max_concurrent_requests: int = 5
    # chunk_size / chunk_overlap are handed to RecursiveCharacterTextSplitter,
    # which measures length with len()
    chunk_size: int = 2000
    chunk_overlap: int = 200

    def __post_init__(self):
        """Build the pipeline components from the dataclass fields.

        Creates the embeddings client, the chat model, the text splitter, the
        document loader, the (initially empty) vector store slot, the two
        rate-limiting semaphores and the prompt templates.

        Raises:
            ValueError: If ``max_concurrent_requests`` is less than 1; a
                semaphore of size 0 would block every request forever.
        """
        if self.max_concurrent_requests < 1:
            raise ValueError("max_concurrent_requests must be at least 1")

        # Initialize the components.
        # Prefer the key passed to the constructor; fall back to the
        # environment so callers that relied on COHERE_API_KEY keep working.
        self.embeddings = CohereEmbeddings(
            model="embed-v4.0",
            cohere_api_key=self.cohere_api_key or os.getenv("COHERE_API_KEY"),
        )

        self.llm = ChatOpenAI(
            model="meta-llama/llama-3.3-70b-instruct",
            openai_api_key=self.openrouter_api_key,
            openai_api_base="https://openrouter.ai/api/v1",
            temperature=0.6,
            max_tokens=1500
        )

        # Text Splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len
        )

        # Keep the loader on the instance so other methods can use it.
        self.document_loader = DocumentLoader()

        # Vector store: stays None until an index is built.
        # self.kb_folder is already set by the dataclass-generated __init__.
        self.vector_store: Optional[FAISS] = None

        # Semaphores for rate limiting
        self.llm_semaphore = asyncio.Semaphore(self.max_concurrent_requests)
        self.embeddings_semaphore = asyncio.Semaphore(self.max_concurrent_requests)

        self.is_initialized = False

        # Set up prompts
        self._setup_prompts()

    def _setup_prompts(self):
        """Create the prompt templates for the pipeline stages and store them on self.

        Templates defined here:
            retrieval_decision_prompt: ``PromptTemplate`` with the single input
                ``query``. It asks the model for its reasoning followed by a
                decision of either "RETRIEVE" or "NO_RETRIEVE".
            rag_prompt: ``ChatPromptTemplate`` with the inputs ``context`` and
                ``question``, used to answer from retrieved context documents.
        """

        # Retrieval decision prompt
        # The template ends by requesting "Reasoning: ..." then "Decision: ...".
        # NOTE(review): presumably downstream code parses that layout - confirm
        # before rewording the template.
        self.retrieval_decision_prompt = PromptTemplate(
            input_variables=["query"],
            template="""
            Analyze the following query to determine if it requires external knowledge retrieval.
            
            Query: "{query}"
            
            Consider:
            1. Does this query ask for specific facts, data, or domain-specific information?
            2. Would the answer benefit from external documents or knowledge base?
            3. Is this asking about general knowledge that can be answered without retrieval?
            4. Does it require recent or specialized information?
            
            Provide your reasoning and then answer with either "RETRIEVE" or "NO_RETRIEVE".
            
            Reasoning: [Explain your decision]
            Decision: [RETRIEVE or NO_RETRIEVE]
            """
        )

        # Answer generation with retrieval prompt
        # Placeholders: {context} for the context documents, {question} for the user question.
        self.rag_prompt = ChatPromptTemplate.from_template("""
            You are a helpful AI assistant. Use the following context documents to answer the user's question accurately and comprehensively.
            
            Context Documents:
            {context}
            
            Question: {question}
            
            Instructions:
            - Base your answer primarily on the provided context
            - If the context doesn't contain sufficient information, acknowledge this
            - Cite specific documents when referencing information
            - Be accurate, detailed, and helpful
            - If you need to use general knowledge to supplement the context, clearly indicate this
            
            Answer:
        """)

