In [None]:
import traceback
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from app.RAG.document_loader import DocumentLoader
from app.RAG.chunker_factory import ChunkerFactory
from app.RAG.embedings_geneartor import EnhancedEmbeddingGenerator
from app.RAG.vector_store import VectorStoreFactory
from app.RAG.vector_search import SearchFactory
from app.RAG.rag_types import StandardRAG, GraphRAG, AdaptiveRAG, RaptorRAG, CorrectiveRAG, IterativeRAG
from app.RAG.prompt_optimizer import PromptOptimizer
from app.RAG.reranking_model import RerankingModel
from app.RAG.fact_checker import FactChecker

class RAGPipeline:
    """Wire together the components of a retrieval-augmented generation pipeline.

    Components: document loader, chunker, embedding generator, vector store,
    search algorithm, seq2seq language model and one implementation from
    ``app.RAG.rag_types`` selected by ``rag_type``. The pipeline either
    creates a new vector store (default) or attaches to an existing one
    (``load_only=True``, see ``load_vector_store``).

    NOTE: in load-only mode ``self.vector_store`` is the raw backend object
    (a chromadb collection, a Pinecone ``Index``, a Weaviate client or a FAISS
    index) rather than a ``VectorStoreFactory`` store.
    """

    # Requested embedding model -> model actually loaded. Applied in BOTH the
    # create and the load path so queries are embedded with the same model
    # that produced the stored vectors.
    _EMBEDDING_MODEL_SUBSTITUTIONS = {
        "Alibaba-NLP/gte-large-en-v1.5": "intfloat/e5-large-v2",
    }

    # Fixed chunking parameters passed to ChunkerFactory.create_chunker.
    _MAX_TOKENS = 100
    _TOKEN_OVERLAP = 10

    def __init__(self, rag_type: str, llm_model: str, embedding_model: str,
                 chunking_option: str, vector_db: str, search_option: str, database_name: str = None,
                 database_id: str = None, load_only: bool = False):
        """
        Initialize RAG pipeline with provided configurations

        Args:
            rag_type (str): Type of RAG implementation to use
            llm_model (str): Language model to use
            embedding_model (str): Embedding model to use (may be substituted,
                see ``_EMBEDDING_MODEL_SUBSTITUTIONS``; ``self.embedding_model``
                keeps the requested name)
            chunking_option (str): Method for chunking documents
            vector_db (str): Vector database to use
            search_option (str): Search algorithm to use
            database_name (str): Name of the database
            database_id (str): ID of the database
            load_only (bool): If True, attach to an existing vector store instead
                of creating a new one; all other components are still built so
                the pipeline can answer queries.

        Raises:
            ValueError: In load-only mode when no existing vector store is found,
                or when ``rag_type`` is not supported.
            Exception: Any other initialization error is printed with its
                traceback and re-raised.
        """
        # Set first so __del__/cleanup work even if initialization fails below.
        self.cleanup_required = False
        self.rag_type = rag_type
        self.llm_model = llm_model
        self.embedding_model = embedding_model
        self.chunking_option = chunking_option
        self.vector_db = vector_db
        self.search_option = search_option
        self.database_name = database_name
        self.database_id = database_id

        effective_embedding_model = self._EMBEDDING_MODEL_SUBSTITUTIONS.get(
            embedding_model, embedding_model
        )

        try:
            if load_only:
                self._init_load_only(effective_embedding_model)
            else:
                self._init_new_store(effective_embedding_model)
        except Exception as e:
            print(f"? Error during initialization: {str(e)}")
            traceback.print_exc()
            raise

    def _init_load_only(self, embedding_model: str):
        """Build a query-only pipeline on top of an already existing vector store."""
        print(f"???? Initializing RAG Pipeline in load-only mode")
        print(f"???? Database: {self.database_name}")
        print(f"???? Vector DB: {self.vector_db}")

        # The store is loaded first so a missing database is reported before
        # any model is loaded.
        print(f"???? Loading existing Vector Store: {self.vector_db}...")
        if self.load_vector_store():
            print("? Vector Store loaded successfully")
        else:
            raise ValueError(f"Failed to load vector store for {self.database_name}")

        self._init_language_model()
        self._init_search_algorithm()
        self._init_embedding_generator(embedding_model)
        # Loader and chunker are required by the RAG implementation constructors.
        self._init_document_loader()
        self._init_chunker()

        if self.rag_type in ["corrective"]:
            self.fact_checker = FactChecker()

        self.initialize_rag_implementation()
        print("???? RAG Pipeline loaded in query-only mode!")

    def _init_new_store(self, embedding_model: str):
        """Build every component and create a new vector store for the database."""
        collection_name = None
        if self.database_name:
            collection_name = f"rag-{self.database_name.lower()}-{self.database_id or 'temp'}"
        print("collection_name", collection_name)
        print("database_name", self.database_name)
        print(f"???? Initializing RAG Pipeline")
        print(f"  RAG Type: {self.rag_type}")
        print(f"  LLM Model: {self.llm_model}")
        print(f"  Embedding Model: {self.embedding_model}")
        print(f"  Chunking Option: {self.chunking_option}")
        print(f"  Vector DB: {self.vector_db}")
        print(f"  Search Option: {self.search_option}")

        self._init_document_loader()
        self._init_chunker()
        self._init_embedding_generator(embedding_model)

        print(f"???? Initializing Vector Store: {self.vector_db}...")
        self.vector_store = VectorStoreFactory.create_store(
            store_type=self.vector_db,
            embedding_model=embedding_model,
            collection_name=collection_name,
            database_name=self.database_name,
        )
        print("? Vector Store initialized successfully")

        self._init_search_algorithm()
        self._init_language_model()

        self.fact_checker = FactChecker()

        self.initialize_rag_implementation()
        print("???? RAG Pipeline initialized completely!")

    def _init_document_loader(self):
        """Create ``self.document_loader``."""
        print("???? Initializing Document Loader...")
        self.document_loader = DocumentLoader()
        print("? Document Loader initialized successfully")

    def _init_chunker(self):
        """Create ``self.chunker`` with the fixed token/overlap settings."""
        print(f"???? Initializing Chunker with option: {self.chunking_option}...")
        self.chunker = ChunkerFactory.create_chunker(
            self.chunking_option, self._MAX_TOKENS, self._TOKEN_OVERLAP
        )
        print("? Chunker initialized successfully")

    def _init_embedding_generator(self, embedding_model: str):
        """Create ``self.embedding_generator`` for ``embedding_model``."""
        print(f"???? Initializing Embedding Generator with model: {embedding_model}...")
        self.embedding_generator = EnhancedEmbeddingGenerator(
            model_name=embedding_model,
            batch_size=32
        )
        print("? Embedding Generator initialized successfully")

    def _init_search_algorithm(self):
        """Create ``self.search_algorithm`` from ``self.search_option``."""
        print(f"???? Initializing Search Algorithm: {self.search_option}...")
        self.search_algorithm = SearchFactory.create_search_algorithm(self.search_option)
        print("? Search Algorithm initialized successfully")

    def _init_language_model(self):
        """Load tokenizer and seq2seq model for ``self.llm_model`` and wrap them in a pipeline."""
        print(f"???? Initializing Language Model: {self.llm_model}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.llm_model)
        self.generator = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=512
        )
        print("? Language Model initialized successfully")

    def initialize_rag_implementation(self):
        """Create ``self.rag_implementation`` for ``self.rag_type``.

        Raises:
            ValueError: If ``self.rag_type`` is not one of standard, graph,
                corrective, iterative, adaptive or raptor.
        """
        print(f"???? Initializing RAG Implementation: {self.rag_type}...")

        base_params = {
            'document_loader': self.document_loader,
            'chunker': self.chunker,
            'embedding_generator': self.embedding_generator,
            'vector_store': self.vector_store,
            'search_algorithm': self.search_algorithm,
            'generator': self.generator
        }

        if self.rag_type == "standard":
            self.rag_implementation = StandardRAG(**base_params)
        elif self.rag_type == "graph":
            self.rag_implementation = GraphRAG(**base_params)
        elif self.rag_type == "corrective":
            if not hasattr(self, 'fact_checker'):
                self.fact_checker = FactChecker()
            base_params['fact_checker'] = self.fact_checker
            self.rag_implementation = CorrectiveRAG(**base_params)
        elif self.rag_type == "iterative":
            self.rag_implementation = IterativeRAG(**base_params)
        elif self.rag_type == "adaptive":
            self.rag_implementation = AdaptiveRAG(
                confidence_threshold=0.7,
                **base_params
            )
        elif self.rag_type == "raptor":
            reranking_model = RerankingModel(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
            prompt_optimizer = PromptOptimizer(
                embedding_generator=self.embedding_generator,
                tokenizer=self.tokenizer
            )

            raptor_params = {
                'token_weight_threshold': 0.5,
                'max_prompt_attempts': 3,
                'reranking_model': reranking_model,
                'prompt_optimizer': prompt_optimizer
            }

            self.rag_implementation = RaptorRAG(**raptor_params, **base_params)
        else:
            raise ValueError(f"Unsupported RAG type: {self.rag_type}")

        print("? RAG Implementation initialized successfully")

    def process_documents(self, zip_path: str):
        """
        Process documents through the selected RAG implementation

        Args:
            zip_path (str): Path to the ZIP file containing documents

        Returns:
            List[Document]: Documents returned by the implementation, or an empty
            list on error (the error is printed and ``cleanup`` is called).
        """
        try:
            print(f"???? Starting document processing for: {zip_path}")
            documents = self.rag_implementation.process_documents(zip_path)

            if documents:
                print(f"? Successfully processed {len(documents)} documents")
                print(f"????? Documents loaded into vector store")
                # Loader resources now need releasing in cleanup().
                self.cleanup_required = True
            else:
                print("?? No documents were processed")

            return documents
        except Exception as e:
            print(f"? Error processing documents: {str(e)}")
            traceback.print_exc()
            self.cleanup()
            return []

    def generate_response(self, query: str, max_length: int = 512) -> str:
        """
        Generate response using the selected RAG implementation

        When the implementation has no ``generate_response`` of its own, context
        is retrieved (via the implementation's ``retrieve_context`` if it has one,
        otherwise via this pipeline's ``retrieve_context``) and prepended to the
        prompt.

        Args:
            query (str): User query
            max_length (int): Maximum length of generated response

        Returns:
            str: Generated response, or an "Error generating response: ..." string
            if anything fails (errors are not raised).
        """
        try:
            print(f"? Generating response for query: {query[:50]}...")

            implementation_generates = hasattr(self.rag_implementation, 'generate_response')

            # For loaded RAG databases, we may need to retrieve context directly
            if hasattr(self, 'vector_store') and not implementation_generates:
                print("???? Using direct context retrieval for loaded RAG database")

                # Fall back to this pipeline's backend-specific retrieval when the
                # implementation cannot retrieve context itself.
                retrieve = getattr(self.rag_implementation, 'retrieve_context', None) or self.retrieve_context
                context_docs = retrieve(query, max_docs=3)

                if context_docs:
                    context_text = "\n\n".join([doc.page_content for doc in context_docs])
                    print(f"? Retrieved {len(context_docs)} context documents")

                    # Construct prompt with context
                    prompt = f"""Context information:
    {context_text}

    Given the context information, please respond to the following:
    {query}"""

                    response = self.generator(prompt, max_length=max_length)[0]['generated_text']
                else:
                    print("?? No relevant context found, generating response without RAG")
                    response = self.generator(query, max_length=max_length)[0]['generated_text']

                return response

            if implementation_generates:
                response = self.rag_implementation.generate_response(query, max_length)

                if response:
                    print("? Response generated successfully")
                else:
                    print("?? No response generated")

                return response

            # Fallback to direct generation without RAG
            print("?? No RAG implementation found, using direct generation")
            return self.generator(query, max_length=max_length)[0]['generated_text']

        except Exception as e:
            print(f"? Error generating response: {str(e)}")
            traceback.print_exc()
            return f"Error generating response: {str(e)}"

    def cleanup(self):
        """Release document-loader resources if documents were processed; errors are printed, not raised."""
        # getattr: __del__ may run on an instance whose __init__ never executed.
        if getattr(self, 'cleanup_required', False):
            try:
                print("???? Starting cleanup process...")
                if hasattr(self, 'document_loader') and self.document_loader:
                    self.document_loader.cleanup()
                print("? Cleanup completed successfully")
            except Exception as e:
                print(f"? Error during cleanup: {str(e)}")

    def __del__(self):
        """Destructor to ensure cleanup"""
        self.cleanup()

    @staticmethod
    def _pinecone_api_key():
        """Return the Pinecone API key from the ``PINECONE_API_KEY`` environment variable, or None."""
        import os
        return os.environ.get("PINECONE_API_KEY") or None

    @staticmethod
    def _with_preferred(preferred, candidates):
        """Return ``candidates`` with ``preferred`` (when not None) moved to the front, without duplicates."""
        ordered = [] if preferred is None else [preferred]
        ordered.extend(name for name in candidates if name != preferred)
        return ordered

    @staticmethod
    def _faiss_ids_to_documents(ids, stored_documents):
        """Map FAISS result ids to entries of ``stored_documents``, skipping invalid ids.

        FAISS fills unused result slots with -1. That id must be rejected
        explicitly: as a Python index it would silently select the last document.
        """
        count = len(stored_documents)
        return [stored_documents[int(i)] for i in ids if 0 <= i < count]

    def retrieve_context(self, query: str, max_docs: int = 3):
        """
        Retrieve relevant context for a given query directly from the vector database

        The query is embedded with ``self.embedding_generator`` and searched in the
        backend named by ``self.vector_db``. For ChromaDB, Pinecone and Weaviate
        several naming patterns are tried in order; the first that can be queried
        wins. For Pinecone and Weaviate, the index or class found by
        ``load_vector_store`` is tried first when known. FAISS uses
        ``self.vector_store`` and ``self.stored_documents``.

        Args:
            query (str): The user query
            max_docs (int): Maximum number of documents to retrieve

        Returns:
            List[Document]: Relevant documents; an empty list when nothing is found
            or an error occurs (errors are printed, not raised).
        """
        try:
            from langchain_core.documents import Document
            query_doc = Document(page_content=query)

            query_embedding = self.embedding_generator.generate_embeddings([query_doc])[0]

            if self.vector_db == "chromadb":
                import chromadb

                client = chromadb.PersistentClient(path="./chroma_db")
                collection_names = [
                    f"rag_{self.database_name.lower()}_Temp",
                    f"rag-{self.database_name.lower()}-{self.database_id or 'temp'}",
                    f"rag_{self.database_name.lower()}_{self.database_id}"
                ]

                for collection_name in collection_names:
                    try:
                        collection = client.get_collection(name=collection_name)
                        results = collection.query(
                            query_embeddings=[query_embedding],
                            n_results=max_docs
                        )
                        return [
                            Document(page_content=doc_content, metadata=metadata or {})
                            for doc_content, metadata in zip(
                                results.get('documents', [[]])[0],
                                results.get('metadatas', [[]])[0]
                            )
                        ]
                    except Exception as e:
                        print(f"Error with ChromaDB collection {collection_name}: {str(e)}")
                        continue

                print("? No suitable ChromaDB collection found")
                return []

            elif self.vector_db == "pinecone":
                from pinecone import Pinecone

                api_key = self._pinecone_api_key()
                if not api_key:
                    print("? PINECONE_API_KEY environment variable is not set")
                    return []
                pc = Pinecone(api_key=api_key)

                index_names = self._with_preferred(
                    getattr(self, 'pinecone_index_name', None),
                    [
                        f"rag-{self.database_name.lower()}-temp",
                        f"rag_{self.database_name.lower()}_{self.database_id}",
                        f"rag-{self.database_name.lower()}-{self.database_id}"
                    ]
                )

                for index_name in index_names:
                    try:
                        index = pc.Index(index_name)
                        results = index.query(
                            vector=query_embedding,
                            top_k=max_docs,
                            include_metadata=True
                        )

                        documents = []
                        for match in results.get('matches', []):
                            match_metadata = match.get('metadata', {})
                            # The chunk text is stored under 'content'; everything else is metadata.
                            content = match_metadata.get('content', '')
                            metadata = {k: v for k, v in match_metadata.items() if k != 'content'}
                            documents.append(Document(page_content=content, metadata=metadata))
                        return documents

                    except Exception as e:
                        print(f"Error with Pinecone index {index_name}: {str(e)}")
                        continue

                print("? No suitable Pinecone index found")
                return []

            elif self.vector_db == "weaviate":
                import weaviate

                client = weaviate.Client(url="http://localhost:8080")

                class_names = self._with_preferred(
                    getattr(self, 'class_name', None),
                    [
                        f"RAG_{self.database_name.replace(' ', '_').lower()}_Temp",
                        f"RAG_{self.database_name.replace(' ', '_').lower()}_{self.database_id}",
                        f"Rag{self.database_name.replace(' ', '')}"
                    ]
                )

                for class_name in class_names:
                    try:
                        results = (
                            client.query.get(class_name, ["content", "file_name", "chunk_id"])
                            .with_near_vector({
                                "vector": query_embedding
                            })
                            .with_limit(max_docs)
                            .do()
                        )

                        documents = []
                        for item in results.get('data', {}).get('Get', {}).get(class_name, []):
                            documents.append(Document(
                                page_content=item.get('content', ''),
                                metadata={
                                    'file_name': item.get('file_name', ''),
                                    'chunk_id': item.get('chunk_id', '')
                                }
                            ))
                        return documents

                    except Exception as e:
                        print(f"Error with Weaviate class {class_name}: {str(e)}")
                        continue

                print("? No suitable Weaviate class found")
                return []

            elif self.vector_db == "faiss":
                import numpy as np

                try:
                    query_vector = np.array([query_embedding]).astype('float32')
                    _distances, ids = self.vector_store.search(query_vector, max_docs)
                    return self._faiss_ids_to_documents(ids[0], self.stored_documents)
                except Exception as e:
                    print(f"FAISS error: {str(e)}")
                    return []

            else:
                print(f"Unsupported vector store type: {self.vector_db}")
                return []

        except Exception as e:
            print(f"Error retrieving context: {str(e)}")
            traceback.print_exc()
            return []

    def load_vector_store(self):
        """Attach ``self.vector_store`` to an existing store for this database.

        Several naming patterns derived from ``database_name``/``database_id`` are
        tried in order. Side effects on success: Pinecone also sets
        ``self.pinecone_index_name``, Weaviate sets ``self.class_name``, and FAISS
        sets ``self.stored_documents``.

        Returns:
            bool: True if a store was found and attached, False otherwise (errors
            are printed, not raised).
        """
        try:
            if self.vector_db == "chromadb":
                import chromadb

                client = chromadb.PersistentClient(path="./chroma_db")

                collection_names = [
                    f"rag_{self.database_name.lower()}_Temp",
                    f"rag-{self.database_name.lower()}-{self.database_id or 'temp'}",
                    f"rag_{self.database_name.lower()}_{self.database_id}"
                ]

                for collection_name in collection_names:
                    try:
                        print(f"Attempting to load ChromaDB collection: {collection_name}")
                        self.vector_store = client.get_collection(name=collection_name)
                        print(f"? Successfully loaded ChromaDB collection: {collection_name}")
                        return True
                    # The "not found" exception type differs between chromadb versions
                    # (ValueError in older releases), so any failure moves on to the next name.
                    except Exception as e:
                        print(f"Collection {collection_name} not found: {str(e)}")
                        continue

                print("? Failed to find any matching ChromaDB collection")
                return False

            elif self.vector_db == "pinecone":
                from pinecone import Pinecone

                # Credentials come from the environment, never from source code.
                # (A key was previously committed here; it should be rotated.)
                api_key = self._pinecone_api_key()
                if not api_key:
                    print("? PINECONE_API_KEY environment variable is not set")
                    return False
                pc = Pinecone(api_key=api_key)

                # Index-name prefixes to look for; only the fixed "pineenv" prefix is used.
                index_names = ["pineenv"]

                available_indexes = pc.list_indexes().names()
                print(f"Available Pinecone indexes: {available_indexes}")

                for index_name in index_names:
                    try:
                        print(f"Attempting to load Pinecone index: {index_name}")
                        matching_indexes = [idx for idx in available_indexes if idx.startswith(index_name)]

                        if not matching_indexes:
                            print(f"No Pinecone index found matching pattern: {index_name}")
                            continue

                        # Lexicographically greatest name (not necessarily the newest index).
                        actual_index_name = max(matching_indexes)
                        self.vector_store = pc.Index(actual_index_name)
                        # Remembered so retrieve_context queries this same index.
                        self.pinecone_index_name = actual_index_name
                        print(f"? Successfully loaded Pinecone index: {actual_index_name}")
                        return True
                    except Exception as e:
                        print(f"Error connecting to Pinecone index {index_name}: {str(e)}")
                        continue

                print("? Failed to find any matching Pinecone index")
                return False

            elif self.vector_db == "weaviate":
                import weaviate

                client = weaviate.Client(url="http://localhost:8080")

                class_names = [
                    f"RAG_{self.database_name.replace(' ', '_').lower()}_Temp",
                    f"RAG_{self.database_name.replace(' ', '_').lower()}_{self.database_id}",
                    f"Rag{self.database_name.replace(' ', '')}"
                ]

                schema = client.schema.get()
                available_classes = [cls['class'] for cls in schema['classes']] if 'classes' in schema else []
                print(f"Available Weaviate classes: {available_classes}")

                for class_name in class_names:
                    try:
                        print(f"Checking for Weaviate class: {class_name}")
                        if client.schema.exists(class_name):
                            # The whole client is stored; queries need class_name as well.
                            self.vector_store = client
                            self.class_name = class_name
                            print(f"? Successfully loaded Weaviate class: {class_name}")
                            return True
                        print(f"Weaviate class {class_name} not found")
                    except Exception as e:
                        print(f"Error checking Weaviate class {class_name}: {str(e)}")

                print("? Failed to find any matching Weaviate class")
                return False

            elif self.vector_db == "faiss":
                import faiss
                import pickle
                import os

                base_names = [
                    f"rag-{self.database_name.lower()}-Temp",
                    f"rag_{self.database_name.lower()}_Temp",
                    f"rag_{self.database_name.lower()}_{self.database_id}",
                    f"rag-{self.database_name.lower()}-{self.database_id or 'temp'}"
                ]

                if os.path.exists("./faiss_indexes"):
                    available_files = os.listdir("./faiss_indexes")
                    print(f"Available FAISS files: {available_files}")

                for base_name in base_names:
                    index_path = f"./faiss_indexes/{base_name}.index"
                    doc_path = f"./faiss_indexes/{base_name}_docs.pkl"
                    try:
                        print(f"Attempting to load FAISS index: {index_path}")
                        if os.path.exists(index_path) and os.path.exists(doc_path):
                            self.vector_store = faiss.read_index(index_path)
                            # SECURITY NOTE(review): pickle.load executes arbitrary code from
                            # the file; this is only safe if ./faiss_indexes is written
                            # exclusively by this application.
                            with open(doc_path, 'rb') as f:
                                self.stored_documents = pickle.load(f)
                            print(f"? Successfully loaded FAISS index: {index_path}")
                            return True
                        print(f"FAISS files not found: {index_path} or {doc_path}")
                    except Exception as e:
                        print(f"Error loading FAISS index {index_path}: {str(e)}")

                print("? Failed to find any matching FAISS index files")
                return False

            print(f"? Unsupported vector database type: {self.vector_db}")
            return False

        except Exception as e:
            print(f"? Error loading vector store: {str(e)}")
            traceback.print_exc()
            return False

ModuleNotFoundError: No module named 'app'

In [None]:
# routes/rag.py
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Body
from sqlalchemy.orm import Session
from typing import List
from ..database.connection import get_db
from ..models.rag import RAGDatabase, RAGFile
from ..schemas.rag import RAGDatabase as RAGDatabaseSchema
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload
import os
import shutil
import zipfile
from datetime import datetime
from app.RAG.rag_pipeline import RAGPipeline

from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from app.RAG.document_loader import DocumentLoader
from app.RAG.chunker_factory import ChunkerFactory
from app.RAG.embedings_geneartor import EnhancedEmbeddingGenerator
from app.RAG.vector_store import VectorStoreFactory
from app.RAG.vector_search import SearchFactory
from app.RAG.rag_types import StandardRAG, GraphRAG, AdaptiveRAG, RaptorRAG
from app.RAG.prompt_optimizer import PromptOptimizer
from app.RAG.reranking_model import RerankingModel
import chromadb
import weaviate
import faiss
import numpy as np
import pinecone
from sentence_transformers import SentenceTransformer
from langchain_core.documents import Document
from pinecone import Pinecone, ServerlessSpec

import re
import uuid


# Every endpoint in this module is mounted under /rag and grouped under the "rag" tag.
router = APIRouter(tags=["rag"], prefix="/rag")
# The ChromaDB PersistentClient calls in this module use ./chroma_db; create it at import time.
os.makedirs("./chroma_db", exist_ok=True)

def process_zip_file(zip_path: str, extract_path: str):
    """Extract a ZIP archive and describe every regular file it contained.

    Files whose extension is in a fixed set of text formats are read as UTF-8,
    falling back to Latin-1; every other file gets an empty ``file_content``.

    Args:
        zip_path: Path of the ZIP archive to open.
        extract_path: Directory the archive is extracted into.

    Returns:
        A list with one dict per extracted file, holding ``file_name`` (base
        name), ``file_extension`` (lower-cased, including the dot),
        ``file_path`` (where the file was actually written), ``file_content``
        and ``file_size`` (in KB). Files that cannot be read are logged and
        skipped.

    Raises:
        HTTPException: 400 if the archive cannot be opened or extracted.
    """
    text_extensions = frozenset({
        '.txt', '.md', '.csv', '.json', '.xml',
        '.yaml', '.yml', '.html', '.htm', '.css',
        '.js', '.py', '.java', '.c', '.cpp', '.h',
        '.hpp', '.sh', '.bat', '.ps1', '.log',
        '.ini', '.conf', '.cfg'
    })

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Extract member-by-member to learn where each member really landed.
            # zipfile strips ".." and absolute components while extracting, so
            # joining the raw member name onto extract_path could point outside
            # the extraction directory (and read an unrelated file).
            extracted = [
                (member, zip_ref.extract(member, extract_path))
                for member in zip_ref.infolist()
            ]
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error processing zip file: {str(e)}"
        ) from e

    file_info = []
    for member, file_path in extracted:
        if member.is_dir():
            continue
        file_name = member.filename
        try:
            file_extension = os.path.splitext(file_name)[1].lower()
            content = ""
            if file_extension in text_extensions:
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        content = f.read()
                except UnicodeDecodeError:
                    # Latin-1 maps every byte to a code point, so this cannot
                    # fail to decode; no further fallback is needed.
                    with open(file_path, 'r', encoding='latin-1') as f:
                        content = f.read()

            file_info.append({
                "file_name": os.path.basename(file_name),
                "file_extension": file_extension,
                "file_path": file_path,
                "file_content": content,
                "file_size": os.path.getsize(file_path) / 1024  # Convert to KB
            })
        except Exception as e:
            print(f"Error processing file {file_name}: {str(e)}")
            continue
    return file_info


@router.post("/create", response_model=RAGDatabaseSchema)
async def create_rag_database(
    dataset: UploadFile = File(...),
    name: str = Form(...),
    rag_type: str = Form(...),
    llm_model: str = Form(...),
    embedding_model: str = Form(...),
    chunking_option: str = Form(...),
    vector_db: str = Form(...),
    search_option: str = Form(...),
    db: Session = Depends(get_db)
):
    """Upload a ZIP dataset, index it with a RAG pipeline and record it in the DB.

    The archive is stored in a fresh per-upload directory under
    ``rag_zip_uploads/``. A matching directory under ``rag_extracted_uploads/``
    is created and saved as ``dataset_path``. Documents are processed by
    ``RAGPipeline.process_documents`` before any rows are written, so a
    processing failure never leaves a database record behind. If the request
    fails, the per-upload directories created here are removed again.

    Raises:
        HTTPException: 400 for a missing/non-ZIP/corrupt upload or when no
            document could be processed; 500 for database or unexpected errors.
    """
    # UploadFile.filename is optional and client-controlled: keep only its base
    # name so it cannot steer the write outside zip_dir (e.g. "../../x.zip").
    upload_name = os.path.basename((dataset.filename or "").replace("\\", "/"))
    if not upload_name.lower().endswith('.zip'):
        raise HTTPException(
            status_code=400,
            detail="File must be a ZIP archive"
        )

    # Create directories if they don't exist
    os.makedirs("rag_zip_uploads", exist_ok=True)
    os.makedirs("rag_extracted_uploads", exist_ok=True)

    # `name` is client-controlled too; only a filesystem-safe form of it goes
    # into directory names. The DB record and the pipeline still get `name`.
    # The random suffix keeps concurrent same-name uploads in separate dirs.
    safe_name = re.sub(r"[^\w.-]", "_", name)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    upload_id = f"{safe_name}_{timestamp}_{uuid.uuid4().hex[:8]}"
    zip_dir = f"rag_zip_uploads/{upload_id}"
    extract_dir = f"rag_extracted_uploads/{upload_id}"

    created = False
    try:
        os.makedirs(zip_dir, exist_ok=True)
        os.makedirs(extract_dir, exist_ok=True)

        # Save zip file
        zip_path = os.path.join(zip_dir, upload_name)
        try:
            with open(zip_path, "wb") as buffer:
                shutil.copyfileobj(dataset.file, buffer)
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to save uploaded file: {str(e)}"
            ) from e

        # Validate zip file: testzip() returns the first bad member name, or None.
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                if zip_ref.testzip() is not None:
                    raise HTTPException(
                        status_code=400,
                        detail="Corrupted ZIP file"
                    )
        except zipfile.BadZipFile as e:
            raise HTTPException(
                status_code=400,
                detail="Invalid ZIP file"
            ) from e

        # Build the index BEFORE creating the DB record. The row id is not known
        # yet, so the store is created under database_id='Temp'; the /test
        # endpoints in this module look stores up by that "Temp" suffix.
        rag_pipeline = RAGPipeline(
            rag_type=rag_type,
            llm_model=llm_model,
            embedding_model=embedding_model,
            chunking_option=chunking_option,
            vector_db=vector_db,
            search_option=search_option,
            database_name=name,
            database_id='Temp'
        )
        processed_documents = rag_pipeline.process_documents(zip_path)

        if not processed_documents:
            raise HTTPException(
                status_code=400,
                detail="No documents were successfully processed"
            )

        # Create RAG database record AFTER processing documents
        try:
            db_rag = RAGDatabase(
                name=name,
                dataset_path=extract_dir,
                zip_file_path=zip_path,
                rag_type=rag_type,
                llm_model=llm_model,
                embedding_model=embedding_model,
                chunking_option=chunking_option,
                vector_db=vector_db,
                search_option=search_option,
                total_files=len(processed_documents),
                status="Processed"
            )
            db.add(db_rag)
            db.flush()  # assigns db_rag.id, needed for the RAGFile foreign keys

            for doc in processed_documents:
                db.add(RAGFile(
                    rag_database_id=db_rag.id,
                    file_name=doc.metadata.get("file_name", ""),
                    file_extension=doc.metadata.get("file_type", ""),
                    file_path=doc.metadata.get("source", ""),
                    file_size=doc.metadata.get("file_size", 0),
                    file_content=doc.page_content
                ))

            db.commit()
            db.refresh(db_rag)
        except SQLAlchemyError as e:
            db.rollback()
            raise HTTPException(
                status_code=500,
                detail=f"Database error: {str(e)}"
            ) from e

        created = True
        return db_rag

    except HTTPException:
        # Re-raise HTTP exceptions unchanged
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process RAG database: {str(e)}"
        ) from e
    finally:
        if not created:
            # Don't leave orphaned uploads on disk for a request that failed.
            shutil.rmtree(zip_dir, ignore_errors=True)
            shutil.rmtree(extract_dir, ignore_errors=True)

@router.get("/list")
async def get_rag_databases(db: Session = Depends(get_db)):
    """Return every RAG database row; database failures become HTTP 500."""
    try:
        rows = db.query(RAGDatabase).all()
    except SQLAlchemyError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Database error: {str(e)}"
        )
    return rows
        

@router.get("/{rag_id}")
async def get_rag_database(rag_id: int, db: Session = Depends(get_db)):
    """Fetch one RAG database row by primary key.

    Raises:
        HTTPException: 404 when no row has ``rag_id``; 500 on a database error.
    """
    try:
        found = db.query(RAGDatabase).filter(RAGDatabase.id == rag_id).first()
    except SQLAlchemyError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Database error: {str(e)}"
        )

    if found is None:
        raise HTTPException(
            status_code=404,
            detail="RAG database not found"
        )
    return found

@router.delete("/{rag_id}")
async def delete_rag_database(rag_id: int, db: Session = Depends(get_db)):
    """Delete a RAG database record and its uploaded files.

    The row is deleted and committed first. The stored ZIP file and dataset
    directory are removed afterwards on a best-effort basis (failures are
    logged), so a filesystem problem can no longer leave a row whose files are
    already gone.

    NOTE(review): vector-store contents (Chroma/Pinecone/Weaviate/FAISS) are
    not removed here.

    Raises:
        HTTPException: 404 if no row has ``rag_id``; 500 on database or
            unexpected errors.
    """
    try:
        db_rag = db.query(RAGDatabase).filter(RAGDatabase.id == rag_id).first()
        if db_rag is None:
            raise HTTPException(
                status_code=404,
                detail="RAG database not found"
            )

        # Read the paths before commit: afterwards the deleted instance is
        # detached and its attributes may be expired.
        zip_file_path = db_rag.zip_file_path
        dataset_path = db_rag.dataset_path

        db.delete(db_rag)
        db.commit()
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500 below.
        raise
    except SQLAlchemyError as e:
        db.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Database error: {str(e)}"
        ) from e
    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete RAG database: {str(e)}"
        ) from e

    # Best-effort file cleanup; the record is already gone at this point.
    try:
        if zip_file_path and os.path.exists(zip_file_path):
            os.remove(zip_file_path)
        if dataset_path and os.path.exists(dataset_path):
            shutil.rmtree(dataset_path)
    except OSError as e:
        print(f"Failed to remove files for RAG database {rag_id}: {str(e)}")

    return {"message": "RAG database deleted successfully"}
    








def _chroma_collection_names(client):
    """Return the names of all collections known to a ChromaDB client.

    Some chromadb versions return collection objects (with ``.name``) from
    ``list_collections()`` and others return plain name strings; both work here.
    """
    return [getattr(c, "name", c) for c in client.list_collections()]


def _chroma_store_status(db_rag):
    """Return status info for the ChromaDB collection backing ``db_rag``."""
    client = chromadb.PersistentClient(path="./chroma_db")
    # Stores are currently created with database_id='Temp', not the row id.
    collection_name = f"rag_{db_rag.name.lower()}_Temp"
    try:
        collection = client.get_collection(name=collection_name)
        return {
            "collection_name": collection.name,
            "total_documents": collection.count(),
            "status": "active"
        }
    except ValueError:
        # Collection not found (raised as ValueError by some chromadb versions).
        return {
            "status": "not_found",
            "available_collections": _chroma_collection_names(client),
            "error": f"No collection found matching {collection_name}"
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
            "available_collections": _chroma_collection_names(client)
        }


def _pinecone_store_status(db_rag):
    """Return status info for the last-sorting Pinecone index named after ``db_rag``."""
    # The API key must come from the environment; never commit it to source.
    pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
    available_indexes = list(pc.list_indexes().names())
    base_name = f"rag-{db_rag.name.lower()}-temp"
    matching_indexes = sorted(
        idx for idx in available_indexes if idx.startswith(base_name)
    )
    if not matching_indexes:
        return {
            "status": "not_found",
            "available_indexes": available_indexes,
            "error": f"No index found matching pattern {base_name}"
        }

    # Lexicographically last match (the newest if names end in a counter).
    index_name = matching_indexes[-1]
    index_stats = pc.Index(index_name).describe_index_stats()
    return {
        "collection_name": index_name,
        "total_vectors": index_stats.get('total_vector_count', 0),
        "dimension": index_stats.get('dimension', 0),
        "status": "active",
        "all_matching_indexes": matching_indexes
    }


def _weaviate_store_status(db_rag):
    """Return status info for the Weaviate class backing ``db_rag``."""
    client = weaviate.Client(url="http://localhost:8080")
    class_name = f"RAG_{db_rag.name.replace(' ', '_').lower()}_Temp"
    if not client.schema.exists(class_name):
        return {
            "status": "not_found",
            "error": f"No class found matching {class_name}"
        }
    count = client.query.aggregate(class_name).with_meta_count().do()
    return {
        "collection_name": class_name,
        "total_objects": count['data']['Aggregate'][class_name][0]['meta']['count'],
        "status": "active"
    }


def _faiss_store_status(db_rag):
    """Return status info for the on-disk FAISS index backing ``db_rag``."""
    collection_name = f"rag_{db_rag.name.lower()}_Temp"
    index = faiss.read_index(f"./faiss_indexes/{collection_name}.index")
    return {
        # Report the name of the index that was actually opened.
        "collection_name": collection_name,
        "total_vectors": index.ntotal,
        "dimension": index.d,
        "status": "active"
    }


# vector_db value -> function returning that backend's status dict.
_STORE_STATUS_CHECKERS = {
    "chromadb": _chroma_store_status,
    "pinecone": _pinecone_store_status,
    "weaviate": _weaviate_store_status,
    "faiss": _faiss_store_status,
}


@router.get("/test/{rag_id}/status")
async def test_vector_store(rag_id: int, db: Session = Depends(get_db)):
    """Report the state of the vector store behind a RAG database.

    Returns the DB row summary under ``db_info`` and a backend-specific dict
    under ``store_info``. ``store_info`` is empty for an unrecognised
    ``vector_db``. Backend failures are reported inside ``store_info``
    (``"status": "error"``) rather than raised.

    Raises:
        HTTPException: 404 if no row has ``rag_id``; 500 on unexpected errors.
    """
    try:
        db_rag = db.query(RAGDatabase).filter(RAGDatabase.id == rag_id).first()
        if not db_rag:
            raise HTTPException(
                status_code=404,
                detail="RAG database not found"
            )

        status_info = {}
        checker = _STORE_STATUS_CHECKERS.get(db_rag.vector_db)
        if checker is not None:
            try:
                status_info = checker(db_rag)
            except Exception as e:
                status_info = {
                    "status": "error",
                    "error": str(e)
                }

        return {
            "status": "success",
            "db_info": {
                "id": db_rag.id,
                "name": db_rag.name,
                "vector_db": db_rag.vector_db,
                "created_at": db_rag.created_at,
                "status": db_rag.status
            },
            "store_info": status_info
        }

    except HTTPException:
        # Let the 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error testing vector store: {str(e)}"
        ) from e



def _search_chroma(db_rag, query_embedding, top_k):
    """Return up to ``top_k`` document texts from the ChromaDB collection."""
    client = chromadb.PersistentClient(path="./chroma_db")
    collection = client.get_collection(name=f"rag_{db_rag.name.lower()}_Temp")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )
    return results.get('documents', [[]])[0]


def _search_pinecone(db_rag, query_embedding, top_k):
    """Return up to ``top_k`` ``metadata['content']`` values from Pinecone."""
    # The API key must come from the environment; never commit it to source.
    pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
    index = pc.Index(f"rag-{db_rag.name.lower()}-temp")
    results = index.query(
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True
    )
    return [
        match.get('metadata', {}).get('content', '')
        for match in results.get('matches', [])
    ]


def _search_weaviate(db_rag, query_embedding, top_k):
    """Return up to ``top_k`` ``content`` values from the Weaviate class."""
    client = weaviate.Client(url="http://localhost:8080")
    class_name = f"RAG_{db_rag.name.replace(' ', '_').lower()}_Temp"
    results = (
        client.query.get(class_name, ["content"])
        .with_near_vector({"vector": query_embedding})
        .with_limit(top_k)
        .do()
    )
    return [
        item.get('content', '')
        for item in results.get('data', {}).get('Get', {}).get(class_name, [])
    ]


def _search_faiss(db_rag, query_embedding, top_k):
    """Return up to ``top_k`` stored documents nearest to the query in FAISS."""
    import pickle

    base_path = f"./faiss_indexes/rag_{db_rag.name.lower()}_Temp"
    index = faiss.read_index(f"{base_path}.index")
    # SECURITY: pickle.load can execute code embedded in the file. Presumably
    # these files are written by this service's own indexing step; never point
    # this at untrusted files.
    with open(f"{base_path}_docs.pkl", 'rb') as f:
        stored_documents = pickle.load(f)

    query_vector = np.asarray(query_embedding, dtype='float32').reshape(1, -1)
    _, ids = index.search(query_vector, top_k)
    # FAISS pads with -1 when it has fewer than top_k hits; -1 must not be
    # used as a Python index (it would silently return the last document).
    return [stored_documents[i] for i in ids[0] if i >= 0]


# vector_db value -> (label used in error messages, search function).
_VECTOR_SEARCHERS = {
    "chromadb": ("ChromaDB", _search_chroma),
    "pinecone": ("Pinecone", _search_pinecone),
    "weaviate": ("Weaviate", _search_weaviate),
    "faiss": ("FAISS", _search_faiss),
}


@router.post("/test/{rag_id}/query")
async def test_rag_query(
    rag_id: int,
    query: str = Body(...),
    top_k: int = Body(5),
    db: Session = Depends(get_db)
):
    """Run ``query`` against a RAG database's vector store and answer it.

    The query is embedded with the row's ``embedding_model``, and up to
    ``top_k`` documents are retrieved from its ``vector_db``. Those documents
    are joined into a prompt for the row's seq2seq ``llm_model``. If nothing is
    retrieved (or ``vector_db`` is unrecognised), a fixed "No relevant
    documents found" response is returned without calling the LLM.

    NOTE(review): both models are loaded on every request, and loading and
    inference run synchronously inside this async handler; consider caching
    them and offloading to a threadpool.

    Raises:
        HTTPException: 400 if ``top_k`` < 1; 404 if no row has ``rag_id``;
            500 on vector-store or generation errors.
    """
    if top_k < 1:
        raise HTTPException(
            status_code=400,
            detail="top_k must be at least 1"
        )

    try:
        db_rag = db.query(RAGDatabase).filter(RAGDatabase.id == rag_id).first()
        if not db_rag:
            raise HTTPException(
                status_code=404,
                detail="RAG database not found"
            )

        # Embedding model for query conversion
        embedding_model = SentenceTransformer(db_rag.embedding_model)
        query_embedding = embedding_model.encode(query).tolist()

        documents = []
        backend = _VECTOR_SEARCHERS.get(db_rag.vector_db)
        if backend is not None:
            label, search = backend
            try:
                documents = search(db_rag, query_embedding, top_k)
            except Exception as e:
                raise HTTPException(
                    status_code=500,
                    detail=f"{label} query error: {str(e)}"
                ) from e

        if not documents:
            return {
                "query": query,
                "response": "No relevant documents found",
                "similar_documents": [],
                "metadata": {
                    "vector_db": db_rag.vector_db,
                    "model": db_rag.llm_model,
                    "embedding_model": db_rag.embedding_model
                }
            }

        # Generate response using LLM
        context = "\n".join(documents)
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"

        tokenizer = AutoTokenizer.from_pretrained(db_rag.llm_model)
        model = AutoModelForSeq2SeqLM.from_pretrained(db_rag.llm_model)
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=512
        )
        response = generator(prompt)[0]["generated_text"]

        return {
            "query": query,
            "response": response,
            "similar_documents": documents,
            "metadata": {
                "vector_db": db_rag.vector_db,
                "model": db_rag.llm_model,
                "embedding_model": db_rag.embedding_model,
                "total_chunks_found": len(documents)
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error testing query: {str(e)}"
        ) from e





# Request body for POST /rag/test/load.
# Documented with comments rather than a docstring: pydantic would publish a
# class docstring as the schema description in the generated OpenAPI document.
class RAGLoadRequest(BaseModel):
    # Name of the RAG database to load; matched exactly against RAGDatabase.name.
    rag_name: str

@router.post("/test/load")
async def load_rag_database(request: RAGLoadRequest, db: Session = Depends(get_db)):
    """Load a RAG database's vector store, looked up by name.

    Always returns a dict: ``{"success": True, "message", "database_info"}`` on
    success, or ``{"success": False, "error"}`` on failure. Errors are reported
    in the body, not as HTTP status codes. The row's status is set to
    ``"Loaded"`` only when ``load_vector_store()`` reports success.

    NOTE(review): the pipeline built here is a local that goes out of scope
    when the request ends. Unless RAGPipeline caches the loaded store
    internally, nothing stays loaded for later requests — confirm what the
    chatbot expects.
    NOTE(review): this passes the row id as database_id, while
    create_rag_database builds stores with database_id='Temp'. Confirm that
    load_vector_store finds the store under this id.
    """
    try:
        db_rag = db.query(RAGDatabase).filter(RAGDatabase.name == request.rag_name).first()
        if not db_rag:
            return {
                "success": False,
                "error": f"RAG database '{request.rag_name}' not found"
            }

        rag_pipeline = RAGPipeline(
            rag_type=db_rag.rag_type,
            llm_model=db_rag.llm_model,
            embedding_model=db_rag.embedding_model,
            chunking_option=db_rag.chunking_option,
            vector_db=db_rag.vector_db,
            search_option=db_rag.search_option,
            database_name=db_rag.name,
            database_id=str(db_rag.id),
            load_only=True
        )

        # load_vector_store() returns False (after logging) when no store
        # could be opened; don't report success or mark the row as loaded then.
        if not rag_pipeline.load_vector_store():
            return {
                "success": False,
                "error": f"Failed to load vector store for RAG database '{request.rag_name}'"
            }

        db_rag.status = "Loaded"
        db.commit()

        return {
            "success": True,
            "message": f"RAG database '{request.rag_name}' loaded successfully",
            "database_info": {
                "id": db_rag.id,
                "name": db_rag.name,
                "vector_db": db_rag.vector_db,
                "status": db_rag.status
            }
        }
    except Exception as e:
        # Leave the session usable if the failure happened during commit.
        db.rollback()
        return {
            "success": False,
            "error": f"Error loading RAG database: {str(e)}"
        }