# In [3]:
import os
import asyncio
import logging 
from pathlib import Path
from typing import Tuple, Any, List, Dict, Optional
from dataclasses import dataclass
from dotenv import load_dotenv

from langchain_cohere import CohereEmbeddings
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import (
    TextLoader, 
    UnstructuredMarkdownLoader,
    JSONLoader,
    UnstructuredHTMLLoader
)

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnableLambda
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult

# Load variables from a local .env file into the process environment
# (python-dotenv). Presumably this supplies the API keys for the Cohere/OpenAI
# clients imported above — confirm which variables the .env is expected to hold.
load_dotenv()
# Configure logging: root logger at INFO level.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (pass force=True to override), so in some notebook environments
# this line may not take effect.
logging.basicConfig(level=logging.INFO)
# Module-level logger for this notebook/module.
logger = logging.getLogger(__name__)

# In [4]:
@dataclass()
class SelfRAGResponse:
    """Complete Self-RAG response together with its reflection metadata.

    A plain data container: fields are stored as given, with no validation.
    """

    # Final answer text.
    answer: str
    # Documents that were retrieved for this response.
    retrieved_docs: List[Document]
    # Score produced by the reflection step; the scale is set by the
    # producer (TODO confirm range, e.g. 0..1).
    reflection_score: float
    # Whether retrieval was judged necessary.
    needs_retrieval: bool
    # Citation strings attached to the answer; format is defined by the producer.
    citations: List[str]
    # Free-text reasoning behind the retrieval decision.
    retrieval_decision_reasoning: str


class RateLimitCallback(AsyncCallbackHandler):
    """Callback handler that bounds concurrent LLM calls with a semaphore.

    One semaphore slot is acquired when an LLM run starts. It is released when
    that run finishes, whether it succeeds (``on_llm_end``) or fails
    (``on_llm_error``). Slots are tracked per ``run_id``, so an end/error event
    for a run that never acquired a slot does not release anything.

    Args:
        semaphore: Shared semaphore whose initial value is the maximum number
            of in-flight LLM calls.
    """

    def __init__(self, semaphore: asyncio.Semaphore):
        super().__init__()
        self.semaphore = semaphore
        # run_id -> number of slots currently held by that run. A run_id of
        # None is used if LangChain does not pass one; start/end still balance.
        self._held: Dict[Any, int] = {}

    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        # Waiting here gates the call. This relies on LangChain awaiting async
        # start callbacks before it sends the request.
        await self.semaphore.acquire()
        # Record the slot only after acquire succeeds. If acquire was
        # cancelled, nothing is owed back.
        run_id = kwargs.get("run_id")
        self._held[run_id] = self._held.get(run_id, 0) + 1

    def _release_for(self, run_id: Any) -> None:
        """Release one slot held by ``run_id``; no-op if it holds none."""
        count = self._held.get(run_id, 0)
        if count <= 0:
            # asyncio.Semaphore is unbounded: releasing an unheld slot would
            # silently raise the concurrency limit, so ignore stray events.
            return
        if count == 1:
            del self._held[run_id]
        else:
            self._held[run_id] = count - 1
        self.semaphore.release()

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self._release_for(kwargs.get("run_id"))

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        # A failed call never reaches on_llm_end, so release its slot here.
        # Without this, every error would permanently lose one slot.
        self._release_for(kwargs.get("run_id"))