# In [5]:
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple
import pickle
import os
import logging


class AdvancedRetrievalSystem:
    """Semantic search over text chunks using a flat FAISS inner-product index.

    Chunk texts are embedded with a SentenceTransformer model, L2-normalized
    and stored in a ``faiss.IndexFlatIP`` index, so the inner product returned
    by FAISS equals the cosine similarity between query and chunk.

    Attributes set by ``__init__``:
        model: the SentenceTransformer used to embed chunks and queries.
        index: the FAISS index, or ``None`` until built or loaded.
        chunks: the chunk dicts, positionally aligned with the index vectors.
        embeddings: the raw (un-normalized) chunk embeddings, or ``None``.
        logger: module-level logger.
    """

    # Similarity cut-offs used by _calculate_confidence (strictly greater than).
    HIGH_CONFIDENCE_THRESHOLD = 0.8
    MEDIUM_CONFIDENCE_THRESHOLD = 0.6

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load the embedding model; the index stays empty until built/loaded.

        Args:
            model_name: name passed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.chunks = []
        self.embeddings = None
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _normalize(vectors):
        """Return row-wise L2-normalized float32 copies of *vectors*.

        Rows with zero norm are left as zeros instead of becoming NaN, which
        would otherwise poison every similarity computed against them.
        """
        vectors = np.asarray(vectors, dtype='float32')
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # avoid 0/0 -> NaN for all-zero embeddings
        return vectors / norms

    def build_index(self, chunks: List[Dict], save_path: str = None):
        """Build the FAISS index from document chunks.

        Args:
            chunks: list of dicts, each with a ``'text'`` entry to embed.
            save_path: optional directory; when given, the index and metadata
                are persisted there via ``save_index``.

        Raises:
            ValueError: if *chunks* is empty.
        """
        if not chunks:
            raise ValueError("Cannot build an index from an empty chunk list.")

        texts = [chunk['text'] for chunk in chunks]

        self.logger.info("Generating embeddings for %d chunks...", len(texts))
        embeddings = self.model.encode(texts, show_progress_bar=True)

        # Inner product over unit-length vectors is cosine similarity.
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(self._normalize(embeddings))

        # Commit state only once everything succeeded, so a failure above
        # cannot leave chunks and index out of sync.
        self.chunks = chunks
        self.embeddings = embeddings
        self.index = index

        self.logger.info("Built FAISS index with %d vectors", self.index.ntotal)

        if save_path:
            self.save_index(save_path)

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to *top_k* chunks most similar to *query*.

        Each result is a shallow copy of the stored chunk with two extra keys:
        ``'similarity_score'`` (float) and ``'confidence'``
        (``"high"``/``"medium"``/``"low"``).

        Args:
            query: text to search for.
            top_k: maximum number of results.

        Returns:
            The matching chunks in the order FAISS returned them; an empty
            list when the index holds no vectors or ``top_k`` is not positive.

        Raises:
            ValueError: if no index has been built or loaded.
        """
        if self.index is None:
            raise ValueError("Index not built. Call build_index() first.")

        # Never ask for more neighbours than the index holds.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []

        query_embedding = self._normalize(self.model.encode([query]))
        scores, indices = self.index.search(query_embedding, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with label -1; a plain
            # "idx < len" check would accept it and return the last chunk.
            if 0 <= idx < len(self.chunks):
                result = self.chunks[int(idx)].copy()
                result['similarity_score'] = float(score)
                result['confidence'] = self._calculate_confidence(score)
                results.append(result)

        return results

    def _calculate_confidence(self, similarity_score: float) -> str:
        """Map a similarity score to ``"high"``, ``"medium"`` or ``"low"``."""
        if similarity_score > self.HIGH_CONFIDENCE_THRESHOLD:
            return "high"
        if similarity_score > self.MEDIUM_CONFIDENCE_THRESHOLD:
            return "medium"
        return "low"

    def save_index(self, save_path: str):
        """Persist the index, chunks and embeddings into directory *save_path*.

        Writes ``faiss_index.bin``, ``chunks.pkl`` and ``embeddings.pkl``,
        creating the directory if needed.

        Raises:
            ValueError: if no index has been built or loaded.
        """
        if self.index is None:
            raise ValueError("Index not built. Call build_index() first.")

        os.makedirs(save_path, exist_ok=True)

        faiss.write_index(self.index, os.path.join(save_path, "faiss_index.bin"))

        with open(os.path.join(save_path, "chunks.pkl"), 'wb') as f:
            pickle.dump(self.chunks, f)

        with open(os.path.join(save_path, "embeddings.pkl"), 'wb') as f:
            pickle.dump(self.embeddings, f)

    def load_index(self, save_path: str):
        """Load an index previously written by ``save_index``.

        SECURITY: the metadata files are read with ``pickle.load``, which can
        execute arbitrary code. Only load directories you created yourself or
        otherwise trust.

        Args:
            save_path: directory containing ``faiss_index.bin``,
                ``chunks.pkl`` and ``embeddings.pkl``.
        """
        index = faiss.read_index(os.path.join(save_path, "faiss_index.bin"))

        with open(os.path.join(save_path, "chunks.pkl"), 'rb') as f:
            chunks = pickle.load(f)

        with open(os.path.join(save_path, "embeddings.pkl"), 'rb') as f:
            embeddings = pickle.load(f)

        # Assign together so a failed read leaves the previous state intact.
        self.index = index
        self.chunks = chunks
        self.embeddings = embeddings

        if self.index.ntotal != len(self.chunks):
            self.logger.warning(
                "Loaded index has %d vectors but %d chunks; results may be incomplete",
                self.index.ntotal,
                len(self.chunks),
            )


class StreamlitRetrieval:
    """Thin adapter that formats retrieval results for display in Streamlit."""

    # Maximum number of characters of chunk text shown per result.
    PREVIEW_LENGTH = 200

    def __init__(self):
        """Create the underlying ``AdvancedRetrievalSystem`` with its defaults."""
        self.retrieval = AdvancedRetrievalSystem()

    def search_for_streamlit(self, query, company, top_k=5) -> List[Dict]:
        """Search and return display-ready rows.

        Args:
            query: text passed to the retrieval system's ``search``.
            company: label copied verbatim into every returned row; it is not
                used to filter results.
            top_k: maximum number of results to request.

        Returns:
            A list of dicts with keys ``"text"`` (preview of at most
            ``PREVIEW_LENGTH`` characters, followed by ``"..."`` only when the
            text was actually cut), ``"score"``, ``"confidence"`` and
            ``"company"``. If the search raises, a single row describing the
            error is returned instead, with the same keys.
        """
        try:
            results = self.retrieval.search(query, top_k=top_k)

            formatted_results = []
            for result in results:
                text = result.get("text", "") or ""
                preview = text[:self.PREVIEW_LENGTH]
                if len(text) > self.PREVIEW_LENGTH:
                    preview += "..."
                formatted_results.append({
                    "text": preview,
                    # The retrieval layer names this key "similarity_score".
                    "score": result.get("similarity_score", 0.0),
                    "confidence": result.get("confidence", "low"),
                    "company": company,
                })

            return formatted_results
        except Exception as e:
            # UI boundary: never crash the app, but keep the traceback in logs.
            logging.getLogger(__name__).exception("Streamlit search failed")
            # Same keys as a success row so the UI can render it uniformly.
            return [{
                "text": f"Error: {str(e)}",
                "score": 0.0,
                "confidence": "low",
                "company": company,
            }]


# Create the shared retrieval adapter at import/cell-execution time.
# Constructing StreamlitRetrieval instantiates AdvancedRetrievalSystem, which
# loads the SentenceTransformer model. No index exists yet at this point:
# build_index() or load_index() must be called on
# streamlit_retrieval.retrieval before searching.
streamlit_retrieval = StreamlitRetrieval()
print("✅ Advanced Retrieval System initialized for Streamlit")


# Output: ✅ Advanced Retrieval System initialized for Streamlit
