In [7]:
# Standard library imports
import os
import asyncio
import logging
from pathlib import Path
from typing import Tuple, Any, List, Dict, Optional
from dataclasses import dataclass
from asyncio import sleep
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type # To implement exponential backoff
import httpx
from dotenv import load_dotenv

# Embeddings and LLM imports
from langchain_cohere import CohereEmbeddings
from langchain_openai import ChatOpenAI

# Vector store imports 
from langchain_community.vectorstores import FAISS

# Document Loader Imports
from langchain_community.document_loaders import (
    TextLoader, 
    UnstructuredMarkdownLoader,
    JSONLoader,
    UnstructuredHTMLLoader,
    PyPDFLoader
)
from langchain_core.documents import Document

# Prompt and template imports
from langchain.prompts import ChatPromptTemplate, PromptTemplate

# LangChain runnables and pipeline imports
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnableLambda

# Callbacks and logging imports
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult

# Configure logging: set up root logging at INFO level and create the
# module-level logger that the classes in the following cells write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from a .env file (when one is found) into the process environment.
# NOTE(review): API keys are presumably supplied this way to later cells - confirm.
load_dotenv()


True

In [6]:
@dataclass
class SelfRAGResponse:
    """Result container for one Self-RAG query, including reflection metadata.

    A plain dataclass: it stores the fields below and performs no validation.
    The code that populates it is not part of this cell, so the per-field
    notes are derived from names and type hints only.
    """
    answer: str  # final answer text
    retrieved_docs: List[Document]  # documents retrieved for the query (presumably empty when retrieval is skipped - confirm)
    reflection_score: float  # NOTE(review): score from the reflection step; range/scale is not defined here
    needs_retrieval: bool  # whether retrieval was judged necessary (per the name; verify against the producer)
    citations: List[str]  # NOTE(review): citation strings; their format is not defined here
    retrieval_decision_reasoning: str  # presumably the explanation behind needs_retrieval - confirm

class RateLimitCallbackHandler(AsyncCallbackHandler):
    """Callback handler that caps concurrent LLM calls with a shared semaphore.

    One permit is acquired when an LLM call starts and handed back when the
    call finishes, whether it succeeded or failed.

    Args:
        semaphore: Semaphore shared by every call that should count against
            the same concurrency limit; its initial value is that limit.
    """

    def __init__(self, semaphore: asyncio.Semaphore):
        self.semaphore = semaphore

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        """Wait for a free permit before the LLM call proceeds."""
        await self.semaphore.acquire()

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Hand the permit back after a successful LLM call."""
        self.semaphore.release()

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Hand the permit back when the LLM call fails.

        A failed call is reported through this hook instead of on_llm_end.
        Without releasing here, every failure would permanently consume a
        permit and later calls would eventually block forever in on_llm_start.
        """
        self.semaphore.release()


In [8]:
class RateLimitedCohereEmbeddings:
    """Wrapper for Cohere embeddings with rate limiting, batching and retry logic.

    Two limits are applied to every request: a semaphore caps how many
    requests are in flight at once, and a minimum pause is kept after the most
    recently completed request. The pause is best-effort when
    max_concurrent > 1, because concurrent callers read the same timestamp.

    Args:
        model: Embedding model name, passed through to CohereEmbeddings.
        cohere_api_key: API key, passed through to CohereEmbeddings.
        max_concurrent: Maximum number of embedding requests in flight at once.
        delay_between_calls: Minimum pause, in seconds, between requests.
        batch_size: Maximum number of texts sent per document-embedding request.

    Raises:
        ValueError: If batch_size is smaller than 1.
    """

    def __init__(self, model: str, cohere_api_key: str, max_concurrent: int = 2, delay_between_calls: float = 0.5, batch_size: int = 30):
        # Fail early: a zero batch size would otherwise surface as a
        # ZeroDivisionError inside aembed_documents, and a negative one would
        # silently embed nothing.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        self.base_embeddings = CohereEmbeddings(
            model=model,
            cohere_api_key=cohere_api_key,
        )

        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.delay_between_calls = delay_between_calls
        self.batch_size = batch_size
        # Event-loop time at which the last request completed. Starting at
        # -inf means "no request yet", so the very first call never waits.
        self.last_call_time = float("-inf")

    async def _throttle(self) -> None:
        """Sleep until delay_between_calls seconds have passed since the last request."""
        elapsed = asyncio.get_running_loop().time() - self.last_call_time
        if elapsed < self.delay_between_calls:
            await sleep(self.delay_between_calls - elapsed)

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=1, max=60),
        # HTTPStatusError is itself an Exception, so in effect any error is retried.
        retry=retry_if_exception_type((httpx.HTTPStatusError, Exception)),
    )
    async def _embed_with_retry(self, texts: List[str]) -> List[List[float]]:
        """Embed one batch of texts, retrying failures with exponential backoff.

        Args:
            texts: The batch to embed in a single request.

        Returns:
            One embedding vector per input text.

        Raises:
            tenacity.RetryError: When all 5 attempts have failed (the retry
                decorator is not configured with reraise=True).
        """
        async with self.semaphore:
            # Ensure minimum delay between calls
            await self._throttle()

            try:
                logger.debug(f"Embedding batch of {len(texts)} texts")
                result = await self.base_embeddings.aembed_documents(texts)
            except Exception as e:
                message = str(e)
                if "429" in message or "Too Many Requests" in message:
                    logger.warning("Rate limit hit, retrying after delay...")
                    # Extra pause for rate limits, on top of the retry backoff.
                    await sleep(2)
                else:
                    logger.error(f"Embedding error: {e}")
                # Re-raise in both cases so the retry decorator sees the failure.
                raise

            self.last_call_time = asyncio.get_running_loop().time()
            return result

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents with batching and rate limiting.

        Args:
            texts: Texts to embed; an empty list returns [] without any request.

        Returns:
            One embedding vector per input text, in input order.
        """
        if not texts:
            return []

        all_embeddings = []
        # Ceiling division: a final partial batch still counts as a batch.
        total_batches = (len(texts) + self.batch_size - 1) // self.batch_size

        logger.info(f"Processing {len(texts)} texts in {total_batches} batches of {self.batch_size}")

        for i in range(0, len(texts), self.batch_size):
            batch_num = (i // self.batch_size) + 1
            batch = texts[i:i + self.batch_size]

            logger.info(f"Processing batch {batch_num}/{total_batches} ({len(batch)} texts)")

            batch_embeddings = await self._embed_with_retry(batch)
            all_embeddings.extend(batch_embeddings)

            # Delay between batches (except for the last one)
            if i + self.batch_size < len(texts):
                await sleep(self.delay_between_calls)

        return all_embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query under the same concurrency and delay limits.

        Unlike document batches, a failed query request is not retried here.
        """
        async with self.semaphore:
            await self._throttle()

            result = await self.base_embeddings.aembed_query(text)
            self.last_call_time = asyncio.get_running_loop().time()
            return result


In [None]:
class DocumentLoader:
    """Loads every supported file under a knowledge-base folder into Documents.

    Files are dispatched on their lower-cased extension to a loader factory.
    When that loader fails, the file is retried once as plain text.
    """

    def __init__(self):
        # Extension -> callable that takes the file path (as str) and returns
        # an object exposing .load().
        self.loaders = {
            ".txt": TextLoader,
            ".md": UnstructuredMarkdownLoader,
            ".json": self._create_json_loader,
            ".pdf": PyPDFLoader,
            ".html": UnstructuredHTMLLoader,
            ".py": TextLoader,
            ".js": TextLoader,
            ".css": TextLoader
        }

    def _create_json_loader(self, file_path: str):
        """Create a JSONLoader that selects each top-level element ('.[]').

        text_content=False is passed so that selected elements need not be
        plain strings.
        """
        return JSONLoader(
            file_path=file_path,
            jq_schema='.[]',
            text_content=False
        )

    @staticmethod
    def _add_file_metadata(docs, file_path: Path) -> None:
        """Record the originating file's path, suffix and name on each document."""
        for doc in docs:
            doc.metadata.update({
                'file_path': str(file_path),
                'file_type': file_path.suffix,
                'file_name': file_path.name
            })

    async def load_documents(self, kb_folder: str) -> List[Document]:
        """Load all documents from the knowledge base folder.

        The folder is searched recursively. Files whose extension has no
        registered loader are skipped. A file that cannot be loaded, even by
        the plain-text fallback, is logged and skipped rather than aborting
        the whole run.

        Args:
            kb_folder: Path of the knowledge-base folder.

        Returns:
            All loaded documents, each tagged with 'file_path', 'file_type'
            and 'file_name' metadata.

        Raises:
            FileNotFoundError: If kb_folder does not exist.
        """
        documents = []
        kb_path = Path(kb_folder)

        if not kb_path.exists():
            raise FileNotFoundError(f"Knowledge base folder not found: {kb_path}")

        # Sorted so the document order does not depend on filesystem listing order.
        # NOTE: the loaders are synchronous, so this coroutine blocks the event
        # loop while files are being read.
        for file_path in sorted(kb_path.glob("**/*")):
            suffix = file_path.suffix.lower()
            if suffix not in self.loaders or not file_path.is_file():
                continue

            make_loader = self.loaders[suffix]
            try:
                docs = make_loader(str(file_path)).load()
            except Exception as e:
                if make_loader is TextLoader:
                    # The fallback is TextLoader too; repeating the identical
                    # call would only fail the same way.
                    logger.error(f"Failed to load {file_path}: {e}")
                    continue

                logger.warning(f"There was an error loading {file_path}: {e}; falling back to TextLoader")
                # Fallback: read the file as plain text
                try:
                    docs = TextLoader(str(file_path)).load()
                except Exception as fallback_error:
                    logger.error(f"Failed to load {file_path} with fallback: {fallback_error}")
                    continue

            self._add_file_metadata(docs, file_path)
            documents.extend(docs)

        logger.info(f"Loaded {len(documents)} documents from {kb_folder}")
        return documents