In [17]:
#!pip install langchain_community
#!pip install langchain_anthropic
#!pip install langchain_huggingface langchain_chroma
#!pip install sentence_transformers langchain_openai
#!pip install pypdf
# !pip install langchain_openai

In [2]:
import os
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader, WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser

USER_AGENT environment variable not set, consider setting it to identify your requests.


In [3]:
load_dotenv()

True

In [1]:
class MultiProviderRAG:
    def __init__(self,
                 provider="groq",  # "groq" or "claude"
                 model_configs=None,
                 embedding_model="text-embedding-3-small",
                 chunk_size=1200,
                 chunk_overlap=200,
                 collection_name="documents"):

        # Default model configurations
        if model_configs is None:
            model_configs = {
                "groq": {
                    "model": "llama-3.1-70b-versatile",
                    "temperature": 0.1
                },
                "claude": {
                    "model": "claude-3-haiku-20240307",
                    "temperature": 0.2
                }
            }

        self.provider = provider
        self.model_configs = model_configs

        # Initialize the selected LLM
        self.llm = self._initialize_llm(provider)

        # Use OpenAI embeddings (consistent across providers)
        self.embeddings = OpenAIEmbeddings(
            model=embedding_model,
            api_key=os.getenv("OPENAI_API_KEY")
        )

        # Configure text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", ". ", "? ", "! ", " ", ""],
            add_start_index=True
        )

        self.collection_name = collection_name
        self.vector_store = None
        self.retriever = None

        # Provider-specific prompt templates
        self.prompt_templates = {
            "groq": ChatPromptTemplate.from_template("""
            You are a helpful AI assistant analyzing documents. Answer the question based on the provided context.

            Be direct and informative:
            - Use the context to provide accurate answers
            - Quote relevant passages when helpful
            - If context is insufficient, explain what's missing
            - Keep responses focused and practical

            Context:
            {context}

            Question: {question}

            Answer:
            """),

            "claude": ChatPromptTemplate.from_template("""
            You are a thoughtful document analyst. Provide a comprehensive answer based on the given context.

            Guidelines:
            - Analyze the context thoroughly before responding
            - Provide nuanced insights where appropriate
            - Reference specific details from the documents
            - Acknowledge uncertainty when context is limited
            - Structure your response clearly

            Document Context:
            {context}

            Question: {question}

            Analysis:
            """)
        }

    def _initialize_llm(self, provider):
        """Initialize LLM based on provider choice"""
        config = self.model_configs[provider]

        if provider == "groq":
            return ChatGroq(
                model=config["model"],
                api_key=os.getenv("GROQ_API_KEY"),
                temperature=config["temperature"]
            )
        elif provider == "claude":
            return ChatAnthropic(
                model=config["model"],
                api_key=os.getenv("ANTHROPIC_API_KEY"),
                temperature=config["temperature"]
            )
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def switch_provider(self, new_provider):
        """Switch between Groq and Claude"""
        if new_provider not in ["groq", "claude"]:
            raise ValueError("Provider must be 'groq' or 'claude'")

        self.provider = new_provider
        self.llm = self._initialize_llm(new_provider)
        print(f"Switched to {new_provider} ({self.model_configs[new_provider]['model']})")

    def load_documents(self, source, source_type="pdf"):
        """Load documents from various sources"""
        if source_type == "pdf":
            if isinstance(source, str):
                source = [source]

            all_documents = []
            for file_path in source:
                loader = PyPDFLoader(file_path)
                documents = loader.load()
                all_documents.extend(documents)
                print(f"Loaded {len(documents)} pages from {file_path}")

            return all_documents

        elif source_type == "web":
            if isinstance(source, str):
                source = [source]

            loader = WebBaseLoader(source)
            documents = loader.load()
            print(f"Loaded {len(documents)} web documents")
            return documents

        elif source_type == "directory":
            loader = DirectoryLoader(
                source,
                glob="**/*.pdf",
                loader_cls=PyPDFLoader,
                show_progress=True
            )
            documents = loader.load()
            print(f"Loaded {len(documents)} documents from directory")
            return documents

        else:
            raise ValueError("source_type must be 'pdf', 'web', or 'directory'")

    def setup_vector_store(self, documents, persist_directory="./rag_db"):
        """Create vector database"""
        text_chunks = self.text_splitter.split_documents(documents)
        print(f"Created {len(text_chunks)} text chunks")

        self.vector_store = Chroma.from_documents(
            documents=text_chunks,
            embedding=self.embeddings,
            collection_name=self.collection_name,
            persist_directory=persist_directory
        )

        self.retriever = self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 6}
        )

        return len(text_chunks)

    def load_vector_store(self, persist_directory="./rag_db"):
        """Load existing vector database"""
        self.vector_store = Chroma(
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=persist_directory
        )

        self.retriever = self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 6}
        )

        print("Loaded existing vector database")

    def query(self, question):
        """Query documents using current provider"""
        if not self.retriever:
            raise ValueError("Vector store not initialized")

        # Get provider-specific prompt
        prompt_template = self.prompt_templates[self.provider]

        # Create analysis chain
        chain = (
            {
                "context": self.retriever | RunnableLambda(self._format_context),
                "question": RunnablePassthrough()
            }
            | prompt_template
            | self.llm
            | StrOutputParser()
        )

        return chain.invoke(question)

    def compare_providers(self, question):
        """Compare responses from both providers"""
        if not self.retriever:
            raise ValueError("Vector store not initialized")

        results = {}
        original_provider = self.provider

        for provider in ["groq", "claude"]:
            try:
                self.switch_provider(provider)
                response = self.query(question)
                results[provider] = {
                    "model": self.model_configs[provider]["model"],
                    "response": response
                }
            except Exception as e:
                results[provider] = {
                    "model": self.model_configs[provider]["model"],
                    "response": f"Error: {str(e)}"
                }

        # Restore original provider
        self.switch_provider(original_provider)

        return results

    def _format_context(self, documents):
        """Format retrieved documents"""
        if not documents:
            return "No relevant documents found."

        formatted = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get('source', 'Unknown')
            page = doc.metadata.get('page', 'N/A')

            formatted.append(
                f"[Doc {i} - {source}, p.{page}]\n{doc.page_content}\n"
            )

        return "\n---\n".join(formatted)

    def get_system_info(self):
        """Get current system configuration"""
        return {
            "current_provider": self.provider,
            "current_model": self.model_configs[self.provider]["model"],
            "available_providers": list(self.model_configs.keys()),
            "vector_store_loaded": self.vector_store is not None,
            "embedding_model": "text-embedding-3-small"
        }

# Usage example
def demo():
    # Initialize with Groq as default
    rag = MultiProviderRAG(
        provider="groq",
        model_configs={
            "groq": {
                "model": "llama-3.1-70b-versatile",
                "temperature": 0.1
            },
            "claude": {
                "model": "claude-3-haiku-20240307",
                "temperature": 0.2
            }
        }
    )

    try:
        # Load documents
        documents = rag.load_documents(["/home/ojas/Downloads/sample.pdf"], source_type="pdf")

        # Create vector store
        rag.setup_vector_store(documents, "./multi_provider_db")

        # Test question
        question = "What are the main topics covered in these documents?"

        # Query with Groq
        print("=== GROQ RESPONSE ===")
        groq_response = rag.query(question)
        print(groq_response)

        # Switch to Claude
        rag.switch_provider("claude")
        print("\n=== CLAUDE RESPONSE ===")
        claude_response = rag.query(question)
        print(claude_response)

        # Compare both providers
        print("\n=== PROVIDER COMPARISON ===")
        comparison = rag.compare_providers("Summarize the key findings")

        for provider, result in comparison.items():
            print(f"\n{provider.upper()} ({result['model']}):")
            print(result['response'])

        # Show system info
        print(f"\nSystem Info: {rag.get_system_info()}")

    except Exception as e:
        print(f"Error: {e}")
        print("\nRequired environment variables:")
        print("- GROQ_API_KEY")
        print("- ANTHROPIC_API_KEY")
        print("- OPENAI_API_KEY (for embeddings)")

In [2]:
def demo():
    # Initialize with Groq as default
    rag = MultiProviderRAG(
        provider="groq",
        model_configs={
            "groq": {
                "model": "llama-3.1-70b-versatile",
                "temperature": 0.1
            },
            "claude": {
                "model": "claude-3-haiku-20240307",
                "temperature": 0.2
            }
        }
    )

    try:
        # Load documents
        documents = rag.load_documents(["sample.pdf"], source_type="pdf")

        # Create vector store
        rag.setup_vector_store(documents, "./multi_provider_db")

        # Test question
        question = "What are the main topics covered in these documents?"

        # Query with Groq
        print("=== GROQ RESPONSE ===")
        groq_response = rag.query(question)
        print(groq_response)

        # Switch to Claude
        rag.switch_provider("claude")
        print("\n=== CLAUDE RESPONSE ===")
        claude_response = rag.query(question)
        print(claude_response)

        # Compare both providers
        print("\n=== PROVIDER COMPARISON ===")
        comparison = rag.compare_providers("Summarize the key findings")

        for provider, result in comparison.items():
            print(f"\n{provider.upper()} ({result['model']}):")
            print(result['response'])

        # Show system info
        print(f"\nSystem Info: {rag.get_system_info()}")

    except Exception as e:
        print(f"Error: {e}")
        print("\nRequired environment variables:")
        print("- GROQ_API_KEY")
        print("- ANTHROPIC_API_KEY")
        print("- OPENAI_API_KEY (for embeddings)")

In [3]:
if __name__ == "__main__":
    demo()

NameError: name 'ChatGroq' is not defined