In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import json
import logging
from datetime import datetime
from typing import Dict, List, Any
import asyncio
import uuid

# Import our RAG system
from multilingual_rag import MultilingualRAG, DocumentChunk

# For advanced LLM integration
import openai
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.llms import Ollama
from langchain.schema import HumanMessage, SystemMessage

# For database integration
import psycopg2
from pymongo import MongoClient
import pinecone

In [None]:
# Configure logging: INFO level on the root logger so provider/database
# setup messages and request errors are printed.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by AdvancedRAGAPI below.
logger = logging.getLogger(__name__)

class AdvancedRAGAPI:
    """Flask RAG API with multiple LLM providers and optional external vector databases.

    On construction it loads a JSON config (recursively merged over built-in
    defaults), builds the ``MultilingualRAG`` retriever, creates every LLM
    provider the config enables, optionally connects an external vector
    database, and registers the HTTP routes on ``self.app``.

    Session histories live in the in-process ``self.sessions`` dict, so they
    are lost when the process exits and are not shared between processes.
    """

    def __init__(self, config_path: str = "rag_config.json"):
        self.app = Flask(__name__)
        CORS(self.app)

        # Load configuration
        self.config = self.load_config(config_path)

        # Initialize RAG system
        self.rag = MultilingualRAG(
            vector_store_path=self.config.get("vector_store_path", "vectorstore")
        )

        # Initialize LLM providers (name -> client); may be empty if none could be built.
        self.llm_providers = self.initialize_llm_providers()

        # Initialize vector database; None means the RAG system's own FAISS store is used.
        self.vector_db = self.initialize_vector_database()

        # Session management: session_id -> {"history": [...], "created_at": iso-timestamp}
        self.sessions = {}

        # Setup routes
        self.setup_routes()

    @staticmethod
    def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict with *override* merged recursively into *base*.

        Nested dicts are merged key by key; any other value in *override*
        replaces the value in *base*. Neither input is mutated at the top level.
        """
        merged = dict(base)
        for key, value in override.items():
            if isinstance(merged.get(key), dict) and isinstance(value, dict):
                merged[key] = AdvancedRAGAPI._deep_merge(merged[key], value)
            else:
                merged[key] = value
        return merged

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Load the JSON config at *config_path*, layered over built-in defaults.

        Sections are merged recursively, so a file that only sets
        ``llm_providers.openai.api_key`` still inherits the default model and
        the ``google``/``ollama`` sections. (A shallow merge would replace the
        whole ``llm_providers`` dict and make provider setup raise KeyError.)

        Args:
            config_path: Path to a UTF-8 JSON file. A missing file is not an
                error: a warning is logged and the defaults are returned.

        Returns:
            The merged configuration dict.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
            ValueError: If the file's top-level JSON value is not an object.
        """
        default_config = {
            "llm_providers": {
                "openai": {"api_key": "", "model": "gpt-3.5-turbo"},
                "google": {"api_key": "", "model": "gemini-pro"},
                "ollama": {"base_url": "http://localhost:11434", "model": "llama2"}
            },
            "vector_database": {
                "type": "faiss",  # faiss, pinecone, postgres, mongodb
                "connection_params": {}
            },
            "retrieval_params": {
                "top_k": 5,
                "similarity_threshold": 0.3
            },
            "generation_params": {
                "max_tokens": 300,
                "temperature": 0.1
            }
        }

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except FileNotFoundError:
            logger.warning(f"Config file {config_path} not found. Using default config.")
            return default_config

        if not isinstance(config, dict):
            raise ValueError(f"Config file {config_path} must contain a JSON object at the top level")

        return self._deep_merge(default_config, config)

    def initialize_llm_providers(self) -> Dict[str, Any]:
        """Build the LLM clients the configuration enables.

        OpenAI and Google Gemini are created only when a non-empty ``api_key``
        is configured; Ollama is always attempted. ``generation_params``
        (``temperature`` and ``max_tokens``) is applied to each provider.
        A construction failure is logged and that provider is skipped, so the
        returned dict may be empty.

        Returns:
            Mapping of provider name (``"openai"``, ``"google"``, ``"ollama"``)
            to its LangChain client.
        """
        providers = {}
        llm_cfg = self.config["llm_providers"]
        gen_cfg = self.config["generation_params"]
        temperature = gen_cfg["temperature"]
        max_tokens = gen_cfg.get("max_tokens")

        # OpenAI
        if llm_cfg["openai"]["api_key"]:
            try:
                openai.api_key = llm_cfg["openai"]["api_key"]
                providers["openai"] = ChatOpenAI(
                    api_key=llm_cfg["openai"]["api_key"],
                    model=llm_cfg["openai"]["model"],
                    temperature=temperature,
                    max_tokens=max_tokens
                )
                logger.info("OpenAI provider initialized")
            except Exception as e:
                logger.error(f"Failed to initialize OpenAI: {e}")

        # Google Gemini
        if llm_cfg["google"]["api_key"]:
            try:
                providers["google"] = ChatGoogleGenerativeAI(
                    google_api_key=llm_cfg["google"]["api_key"],
                    model=llm_cfg["google"]["model"],
                    temperature=temperature,
                    max_output_tokens=max_tokens
                )
                logger.info("Google Gemini provider initialized")
            except Exception as e:
                logger.error(f"Failed to initialize Google Gemini: {e}")

        # Ollama (local). Construction does not contact the server, so an
        # unreachable server only shows up as errors at query time.
        try:
            providers["ollama"] = Ollama(
                base_url=llm_cfg["ollama"]["base_url"],
                model=llm_cfg["ollama"]["model"],
                temperature=temperature,
                num_predict=max_tokens
            )
            logger.info("Ollama provider initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Ollama: {e}")

        return providers

    def _connection_params(self, db_type: str) -> Dict[str, Any]:
        """Return the connection parameters for vector database *db_type*.

        Two layouts are accepted: flat
        (``"connection_params": {"api_key": ...}``) and per-backend
        (``"connection_params": {"pinecone": {...}, "postgres": {...}}``),
        which is the layout ``generate_config_file`` writes. If a dict exists
        under the *db_type* key it is used; otherwise the flat dict is used.
        """
        params = self.config["vector_database"].get("connection_params") or {}
        nested = params.get(db_type)
        return nested if isinstance(nested, dict) else params

    def initialize_vector_database(self):
        """Connect the external vector database named by ``vector_database.type``.

        Returns:
            A Pinecone index, psycopg2 connection or pymongo database handle,
            or None for ``faiss``/unknown types or when setup fails.
        """
        db_type = self.config["vector_database"]["type"]

        if db_type == "pinecone":
            return self.setup_pinecone()
        elif db_type == "postgres":
            return self.setup_postgres()
        elif db_type == "mongodb":
            return self.setup_mongodb()
        else:
            # Default to FAISS (already in RAG system)
            return None

    def setup_pinecone(self):
        """Open the configured Pinecone index, or return None on failure.

        Reads ``api_key``, ``index_name`` (default ``"rag-index"``) and, for
        the legacy client only, ``environment``. pinecone-client >= 3 removed
        the module-level ``init`` in favour of a ``Pinecone`` client object;
        both APIs are supported here.
        """
        params = self._connection_params("pinecone")
        api_key = params.get("api_key")
        if not api_key:
            logger.warning("Pinecone selected but no api_key configured; falling back to FAISS")
            return None

        index_name = params.get("index_name", "rag-index")
        try:
            if hasattr(pinecone, "Pinecone"):
                client = pinecone.Pinecone(api_key=api_key)
                return client.Index(index_name)

            # Legacy pinecone-client 2.x module-level API.
            init_kwargs = {"api_key": api_key}
            if params.get("environment"):
                init_kwargs["environment"] = params["environment"]
            pinecone.init(**init_kwargs)
            return pinecone.Index(index_name)
        except Exception as e:
            logger.error(f"Failed to setup Pinecone: {e}")
        return None

    def setup_postgres(self):
        """Open a psycopg2 connection for pgvector, or return None on failure.

        Reads ``host``, ``port``, ``database``, ``user`` and ``password``
        (defaults: localhost, 5432, rag_db, postgres, empty password).
        """
        try:
            params = self._connection_params("postgres")
            conn = psycopg2.connect(
                host=params.get("host", "localhost"),
                port=params.get("port", 5432),
                database=params.get("database", "rag_db"),
                user=params.get("user", "postgres"),
                password=params.get("password", "")
            )
            return conn
        except Exception as e:
            logger.error(f"Failed to setup PostgreSQL: {e}")
        return None

    def setup_mongodb(self):
        """Return a pymongo database handle, or None on failure.

        Reads ``connection_string`` and ``database``. NOTE(review): MongoClient
        connects lazily, so an unreachable server is not detected here.
        """
        try:
            params = self._connection_params("mongodb")
            client = MongoClient(params.get("connection_string", "mongodb://localhost:27017/"))
            db = client[params.get("database", "rag_db")]
            return db
        except Exception as e:
            logger.error(f"Failed to setup MongoDB: {e}")
        return None

    @staticmethod
    def _safe_upload_path(upload_dir: str, filename: str):
        """Return a path inside *upload_dir* for a client-supplied *filename*.

        Only the final path component is kept, so names like
        ``"../../etc/passwd"`` or ``"..\\x.pdf"`` cannot escape *upload_dir*.
        Returns None when no usable name remains (empty, ``"."`` or ``".."``).
        """
        # Normalise Windows separators so basename strips them on POSIX too.
        name = os.path.basename((filename or "").replace("\\", "/"))
        if name in ("", ".", ".."):
            return None
        return os.path.join(upload_dir, name)

    def setup_routes(self):
        """Register the HTTP endpoints on ``self.app``."""

        @self.app.route('/health', methods=['GET'])
        def health_check():
            """Health check endpoint"""
            return jsonify({
                "status": "healthy",
                "timestamp": datetime.now().isoformat(),
                "version": "1.0.0"
            })

        @self.app.route('/chat', methods=['POST'])
        def chat():
            """Answer a query.

            JSON body: ``query`` (required, non-empty string), plus optional
            ``session_id``, ``llm_provider`` (default ``"openai"``) and
            ``include_sources`` (default True).
            """
            try:
                # silent=True: a missing or malformed JSON body yields None
                # (answered with 400) instead of raising and becoming a 500.
                data = request.get_json(silent=True)

                # Validate input
                if not isinstance(data, dict) or 'query' not in data:
                    return jsonify({"error": "Missing 'query' in request body"}), 400

                query = data['query']
                if not isinstance(query, str) or not query.strip():
                    return jsonify({"error": "'query' must be a non-empty string"}), 400

                # `or` also replaces an explicit null session_id with a fresh one.
                session_id = str(data.get('session_id') or uuid.uuid4())
                llm_provider = data.get('llm_provider', 'openai')
                include_sources = data.get('include_sources', True)

                # Process query
                result = self.process_query(
                    query=query,
                    session_id=session_id,
                    llm_provider=llm_provider,
                    include_sources=include_sources
                )

                return jsonify(result)

            except Exception as e:
                logger.exception(f"Error in chat endpoint: {e}")
                return jsonify({"error": str(e)}), 500

        @self.app.route('/upload', methods=['POST'])
        def upload_document():
            """Save an uploaded file under ``uploads/`` and add it to the knowledge base."""
            try:
                if 'file' not in request.files:
                    return jsonify({"error": "No file uploaded"}), 400

                file = request.files['file']
                if file.filename == '':
                    return jsonify({"error": "No file selected"}), 400

                # Save uploaded file. The client controls file.filename, so it
                # is reduced to a bare name to prevent path traversal.
                upload_dir = "uploads"
                file_path = self._safe_upload_path(upload_dir, file.filename)
                if file_path is None:
                    return jsonify({"error": "Invalid file name"}), 400
                os.makedirs(upload_dir, exist_ok=True)
                file.save(file_path)

                # Process document
                self.rag.build_knowledge_base([file_path])

                return jsonify({
                    "message": "Document uploaded and processed successfully",
                    "filename": os.path.basename(file_path),
                    "stats": self.rag.get_stats()
                })

            except Exception as e:
                logger.exception(f"Error in upload endpoint: {e}")
                return jsonify({"error": str(e)}), 500

        @self.app.route('/stats', methods=['GET'])
        def get_stats():
            """Return RAG statistics plus available providers and active session count."""
            try:
                stats = self.rag.get_stats()
                stats["llm_providers"] = list(self.llm_providers.keys())
                stats["active_sessions"] = len(self.sessions)
                return jsonify(stats)
            except Exception as e:
                logger.exception(f"Error in stats endpoint: {e}")
                return jsonify({"error": str(e)}), 500

        @self.app.route('/session/<session_id>', methods=['GET'])
        def get_session(session_id):
            """Return a session's history (an empty history for unknown ids)."""
            try:
                session = self.sessions.get(session_id, {"history": []})
                return jsonify(session)
            except Exception as e:
                logger.exception(f"Error in session endpoint: {e}")
                return jsonify({"error": str(e)}), 500

        @self.app.route('/session/<session_id>', methods=['DELETE'])
        def clear_session(session_id):
            """Delete a session's history; succeeds even if the id is unknown."""
            try:
                self.sessions.pop(session_id, None)
                return jsonify({"message": "Session cleared"})
            except Exception as e:
                logger.exception(f"Error clearing session: {e}")
                return jsonify({"error": str(e)}), 500

    def process_query(self, query: str, session_id: str, llm_provider: str = "openai",
                     include_sources: bool = True) -> Dict[str, Any]:
        """Retrieve context, generate an answer, record it in the session, and build the reply.

        Args:
            query: User question.
            session_id: Key into ``self.sessions``; created on first use.
            llm_provider: Name of the provider to try first (see
                ``generate_response_with_provider`` for the fallback).
            include_sources: When true, add up to 3 retrieved chunks (text
                truncated to 200 chars) under ``"sources"``.

        Returns:
            Dict with ``session_id``, ``query``, ``response``, ``metadata`` and
            optionally ``sources``.
        """
        start_time = datetime.now()

        # Retrieve relevant chunks
        retrieved_chunks = self.rag.retrieve_relevant_chunks(
            query,
            k=self.config["retrieval_params"]["top_k"]
        )

        # Keep only chunks scoring at or above the threshold.
        # NOTE(review): assumes higher score = more similar; confirm with MultilingualRAG.
        threshold = self.config["retrieval_params"]["similarity_threshold"]
        filtered_chunks = [
            (chunk, score) for chunk, score in retrieved_chunks
            if score >= threshold
        ]

        # Generate response using specified LLM provider
        response = self.generate_response_with_provider(
            query, filtered_chunks, llm_provider
        )

        # Create the session on first use and append this interaction.
        session = self.sessions.setdefault(
            session_id, {"history": [], "created_at": start_time.isoformat()}
        )

        interaction = {
            "timestamp": start_time.isoformat(),
            "query": query,
            "response": response,
            "llm_provider": llm_provider,
            "retrieved_chunks": len(filtered_chunks),
            "processing_time": (datetime.now() - start_time).total_seconds()
        }

        session["history"].append(interaction)

        # Prepare response
        result = {
            "session_id": session_id,
            "query": query,
            "response": response,
            "metadata": {
                "llm_provider": llm_provider,
                "retrieved_chunks": len(filtered_chunks),
                "processing_time": interaction["processing_time"],
                "timestamp": start_time.isoformat()
            }
        }

        if include_sources:
            result["sources"] = [
                {
                    "text": chunk.text[:200] + "..." if len(chunk.text) > 200 else chunk.text,
                    "source": chunk.source,
                    "page": chunk.page_number,
                    "similarity_score": float(score)
                }
                for chunk, score in filtered_chunks[:3]  # Top 3 sources
            ]

        return result

    def generate_response_with_provider(self, query: str, retrieved_chunks: List,
                                      llm_provider: str) -> str:
        """Answer *query* from *retrieved_chunks* using the named LLM provider.

        Args:
            query: User question; Bengali script selects the Bengali
                no-result message and fallback language.
            retrieved_chunks: ``(chunk, score)`` pairs; chunks need ``.text``.
            llm_provider: Key into ``self.llm_providers``.

        Returns:
            A fixed "not found" message when there are no chunks. Otherwise the
            provider's answer; if the provider is unavailable or raises, the
            RAG system's built-in response generation is used instead.
        """
        is_bengali = self.is_bengali(query)

        if not retrieved_chunks:
            return "দুঃখিত, আপনার প্রশ্নের উত্তর আমার জ্ঞানভাণ্ডারে খুঁজে পাইনি।" if is_bengali else "I couldn't find relevant information to answer your question."

        # Prepare context
        context = "\n\n".join([
            f"[Score: {score:.3f}] {chunk.text}"
            for chunk, score in retrieved_chunks
        ])

        # System prompt
        system_prompt = """You are a helpful multilingual assistant specializing in Bengali literature and academic content.
        Answer questions based strictly on the provided context.
        If the question is in Bengali, answer in Bengali. If in English, answer in English.
        Provide direct, accurate answers. If the answer is not in the context, say so politely."""

        user_prompt = f"""Context from documents:
{context}

Question: {query}

Please provide a direct and accurate answer based on the context."""

        provider = self.llm_providers.get(llm_provider)
        if provider is None:
            logger.warning(f"LLM provider '{llm_provider}' is not available; using built-in response generation")
        else:
            try:
                if llm_provider == "ollama":
                    # Ollama is a plain completion LLM: one prompt string in, a string out.
                    full_prompt = f"{system_prompt}\n\n{user_prompt}"
                    response = provider.invoke(full_prompt)
                    return response.strip()

                # Chat models take a message list and return a message object.
                messages = [
                    SystemMessage(content=system_prompt),
                    HumanMessage(content=user_prompt)
                ]
                response = provider.invoke(messages)
                return response.content.strip()

            except Exception as e:
                logger.error(f"Error with {llm_provider}: {e}")

        # Fallback to built-in RAG response
        return self.rag._simple_response_generation(query, retrieved_chunks, 'bn' if is_bengali else 'en')

    def is_bengali(self, text: str) -> bool:
        """Return True if *text* contains any character in the Bengali Unicode block (U+0980-U+09FF)."""
        import re
        return bool(re.search(r'[\u0980-\u09FF]', text))

    def run(self, host='0.0.0.0', port=5000, debug=False):
        """Run the Flask development server."""
        if debug and host not in ("127.0.0.1", "localhost", "::1"):
            # Flask's debug mode enables the interactive Werkzeug debugger,
            # which allows code execution by anyone who can reach the server.
            logger.warning(f"Debug mode enabled on non-loopback host {host}; do not expose this publicly")
        logger.info(f"Starting RAG API server on {host}:{port}")
        self.app.run(host=host, port=port, debug=debug)

# Configuration file generator
def generate_config_file(output_path: str = "rag_config.json", overwrite: bool = True):
    """Write a sample configuration file with placeholder credentials.

    The file contains the sections AdvancedRAGAPI reads (llm_providers,
    vector_database, vector_store_path, retrieval_params, generation_params).
    ``vector_database.connection_params`` holds one sub-dict per backend.
    Replace the placeholder API keys before starting the server: any
    non-empty key makes the server try to initialise that provider.

    Args:
        output_path: Destination file (default ``rag_config.json``).
        overwrite: When False, refuse to replace an existing file so real
            credentials are not clobbered by placeholders.

    Raises:
        FileExistsError: If ``overwrite`` is False and *output_path* exists.
    """
    config = {
        "llm_providers": {
            "openai": {
                "api_key": "your-openai-api-key",
                "model": "gpt-3.5-turbo"
            },
            "google": {
                "api_key": "your-google-api-key",
                "model": "gemini-pro"
            },
            "ollama": {
                "base_url": "http://localhost:11434",
                "model": "llama2"
            }
        },
        "vector_database": {
            "type": "faiss",
            "connection_params": {
                "pinecone": {
                    "api_key": "your-pinecone-api-key",
                    "index_name": "rag-index"
                },
                "postgres": {
                    "host": "localhost",
                    "database": "rag_db",
                    "user": "postgres",
                    "password": "password"
                },
                "mongodb": {
                    "connection_string": "mongodb://localhost:27017/",
                    "database": "rag_db"
                }
            }
        },
        "vector_store_path": "hsc_bangla_vectorstore",
        "retrieval_params": {
            "top_k": 5,
            "similarity_threshold": 0.3
        },
        "generation_params": {
            "max_tokens": 300,
            "temperature": 0.1
        }
    }

    # 'x' mode creates the file exclusively and raises FileExistsError if it exists.
    mode = 'w' if overwrite else 'x'
    with open(output_path, mode, encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    print(f"Configuration file '{output_path}' generated successfully!")


In [None]:
# Main execution: parse CLI flags, then either write a sample config or serve the API.
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="RAG API Server")
    cli.add_argument("--config", default="rag_config.json", help="Configuration file path")
    cli.add_argument("--host", default="0.0.0.0", help="Host address")
    cli.add_argument("--port", type=int, default=5000, help="Port number")
    cli.add_argument("--debug", action="store_true", help="Enable debug mode")
    cli.add_argument("--generate-config", action="store_true", help="Generate sample config file")

    options = cli.parse_args()

    if not options.generate_config:
        # Build the API (config, providers, vector DB, routes) and block serving requests.
        server = AdvancedRAGAPI(config_path=options.config)
        server.run(host=options.host, port=options.port, debug=options.debug)
    else:
        generate_config_file()