In [None]:
import os
import pickle
import time

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage, SystemMessage

# --- CONFIG ---
# Locations of the prebuilt FAISS index and the pickled chunk list. The
# defaults are machine-specific Windows paths; set the environment variables
# RAG_FAISS_INDEX_PATH / RAG_CHUNKS_MAPPING_PATH to use other copies without
# editing this file.
FAISS_INDEX_PATH = os.environ.get(
    "RAG_FAISS_INDEX_PATH",
    r"C:\Users\kau75421\LLMprojects\Marketing_campaginer\Recommender_Systems\Notebooks\faiss.index",
)
CHUNKS_MAPPING_PATH = os.environ.get(
    "RAG_CHUNKS_MAPPING_PATH",
    r"C:\Users\kau75421\LLMprojects\Marketing_campaginer\Recommender_Systems\Notebooks\faiss_data.pkl",
)
# Sentence-transformers model used to embed queries.
# NOTE(review): presumably the same model (and normalisation setting) that was
# used to build the index -- query and index vectors must share one space.
EMBEDDING_MODEL = 'all-MiniLM-L6-v2'
NORMALIZE = True          # passed as normalize_embeddings= when encoding queries
TOP_K = 3                 # neighbours requested from FAISS per query
SCORE_THRESHOLD = 0.2     # minimum FAISS score a hit needs to be kept
OLLAMA_MODEL = "mistral:7b-instruct-q4_0"  # local Ollama model tag

# --- GLOBAL VARIABLES FOR CACHING ---
# Populated lazily by load_models_once(); None means "not loaded yet".
embedder = None
index = None
chunk_mapping = None
llm = None

# --- OPTIMIZED UTILS ---
def load_models_once():
    """Fill the module-level caches: embedder, FAISS index, chunk mapping, LLM.

    Each resource is built only while its global is still ``None``, so repeat
    calls after the first are no-ops. The FAISS index and the chunk mapping are
    (re)loaded together whenever either of them is missing.
    """
    global embedder, index, chunk_mapping, llm

    if embedder is None:
        print("🔄 Loading embedding model...")
        embedder = SentenceTransformer(EMBEDDING_MODEL)

    store_missing = index is None or chunk_mapping is None
    if store_missing:
        print("🔄 Loading FAISS index...")
        index = faiss.read_index(FAISS_INDEX_PATH)
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file -- only point CHUNKS_MAPPING_PATH at a trusted, locally built file.
        with open(CHUNKS_MAPPING_PATH, "rb") as mapping_file:
            chunk_mapping = pickle.load(mapping_file)

    if llm is None:
        print("🔄 Loading LLM...")
        generation_options = {
            "temperature": 0.0,  # no sampling randomness
            "num_predict": 200,  # cap on generated tokens to bound response length
            "top_k": 10,         # narrower candidate pool per step
            "top_p": 0.9,
        }
        llm = Ollama(model=OLLAMA_MODEL, **generation_options)

def embed_query_fast(query):
    """Embed one query string as a float32 matrix ready for ``index.search``.

    The query is wrapped in a single-element batch; normalisation follows the
    module-level ``NORMALIZE`` flag and the progress bar is suppressed.
    """
    batch = [query]
    vectors = embedder.encode(
        batch,
        normalize_embeddings=NORMALIZE,
        show_progress_bar=False,
    )
    return vectors.astype("float32")

def retrieve_and_filter(query_embedding, k=TOP_K, score_threshold=None):
    """Search the FAISS index and return the distinct chunks that pass the score cut.

    Args:
        query_embedding: query matrix as produced by ``embed_query_fast``; only
            the first query row of the search result is used.
        k: number of nearest neighbours to request from the index.
        score_threshold: minimum score a hit must reach to be kept. ``None``
            (the default) reads the module-level ``SCORE_THRESHOLD`` at call time.

    Returns:
        Chunks from ``chunk_mapping`` in search-result order, with duplicates
        removed (first occurrence wins). May be empty.
    """
    if score_threshold is None:
        score_threshold = SCORE_THRESHOLD

    distances, indices = index.search(query_embedding, k)

    seen = set()
    filtered_chunks = []
    for idx, score in zip(indices[0], distances[0]):
        # FAISS pads the result with label -1 when it finds fewer than k
        # neighbours; without this guard chunk_mapping[-1] would silently
        # return the *last* chunk as if it were a real hit.
        if idx < 0:
            continue
        # NOTE(review): ">=" treats a higher score as better, which holds for an
        # inner-product index over normalized embeddings; for an L2 index the
        # comparison would be inverted -- confirm the index's metric.
        if score >= score_threshold:
            chunk = chunk_mapping[int(idx)]
            if chunk not in seen:
                seen.add(chunk)
                filtered_chunks.append(chunk)

    return filtered_chunks

def build_concise_prompt(chunks, user_query, *, max_chunks=2, max_chars=300):
    """Build a short LLM prompt from the best retrieved chunks and the question.

    Only the first ``max_chunks`` chunks are used. Any chunk longer than
    ``max_chars`` characters is cut to that length and ``"..."`` is appended,
    which keeps the prompt small for faster inference.

    Args:
        chunks: retrieved text chunks, best match first.
        user_query: the user's question.
        max_chunks: how many chunks to include (default 2).
        max_chars: per-chunk character budget before truncation (default 300).

    Returns:
        The prompt: an instruction line, the chunks separated by ``---`` lines,
        then ``Q:`` / ``A:`` lines.
    """
    def _truncate(text):
        # Chunks at exactly max_chars are kept whole; only longer ones get "...".
        return text[:max_chars] + "..." if len(text) > max_chars else text

    context = "\n---\n".join(_truncate(chunk) for chunk in chunks[:max_chunks])

    return f"""Based on this product info, answer briefly:

{context}

Q: {user_query}
A:"""

def get_llm_response_fast(prompt):
    """Send ``prompt`` to the cached LLM and return its reply as stripped text.

    A plain-string reply is used directly, an object with a ``content``
    attribute contributes that attribute, and anything else is ``str()``-ed.
    Any exception raised while invoking or post-processing is turned into a
    warning string instead of propagating, so the interactive loop keeps going.
    """
    try:
        reply = llm.invoke(prompt)
        if not isinstance(reply, str):
            reply = reply.content if hasattr(reply, 'content') else str(reply)
        return reply.strip()
    except Exception as exc:
        return f"⚠️ Error getting response: {str(exc)}"

# --- OPTIMIZED MAIN ---
def main(*, show_chunks=True):
    """Run the interactive retrieve-then-answer loop on stdin/stdout.

    Loads every model once, then repeatedly reads a question, embeds it,
    retrieves matching chunks, prompts the LLM and prints the answer with a
    per-stage timing breakdown. Typing ``exit``, ``quit`` or ``q`` ends the
    loop, as do EOF and Ctrl-C while waiting for input.

    Args:
        show_chunks: print the raw retrieved chunks before the answer
            (default True, as before); pass False for shorter console output.
    """
    print("🚀 Starting optimized RAG system...")

    # Load everything once at startup
    load_models_once()
    print("✅ All models loaded! Ready for queries.\n")

    while True:
        try:
            user_query = input("🔍 Ask about products (or 'exit'): ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or user interrupted: leave cleanly instead of
            # surfacing a traceback.
            print()
            break
        if user_query.lower() in ['exit', 'quit', 'q']:
            break

        if len(user_query) < 3:
            print("⚠️ Please enter a longer query.")
            continue

        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump if the system time is adjusted.
        start_time = time.perf_counter()

        # Step 1: embed the query
        query_embedding = embed_query_fast(user_query)
        embed_time = time.perf_counter()

        # Step 2: nearest-neighbour search, score filter, de-duplication
        chunks = retrieve_and_filter(query_embedding)
        retrieval_time = time.perf_counter()

        if not chunks:
            print("⚠️ No relevant products found. Try a different query.")
            continue

        # Step 3: build the prompt (only a prefix of the chunks is used)
        prompt = build_concise_prompt(chunks, user_query)

        print(f"\n📄 Found {len(chunks)} relevant chunks")
        if show_chunks:
            print(chunks)

        # Step 4: ask the LLM
        print("\n💬 Answer:")
        response = get_llm_response_fast(prompt)
        # Stop the clock before printing so the console write isn't counted as
        # LLM time (prompt building and the chunk dump above still are).
        llm_time = time.perf_counter()
        print(response)

        # Timing breakdown
        total_time = llm_time - start_time
        print(f"\n⏱️ Timing: Embed: {embed_time-start_time:.2f}s | "
              f"Retrieve: {retrieval_time-embed_time:.2f}s | "
              f"LLM: {llm_time-retrieval_time:.2f}s | "
              f"Total: {total_time:.2f}s")


if __name__ == "__main__":
    main()

🚀 Starting optimized RAG system...
🔄 Loading embedding model...


  llm = Ollama(


🔄 Loading FAISS index...
🔄 Loading LLM...
✅ All models loaded! Ready for queries.


📄 Found 3 relevant chunks
['description: these high-quality, traditional football pants are made of 82% polyester/18% spandex heavyweight fabric with "quick recovery" elasticity and a matte finish. cover stitched for optimum durability. hassle-free, factory-installed pads come already sewn in to reduce prep time so athletes can get ready for games and practice quickly. a built-in, full-length covered web belt takes the fuss out of securing the pants. the 2.5 elastic waistband is lined with gripper strips to keep your jersey tucked. the stretch fit fabric fits tightly and uses special 4-way stretch fabrics to expand over the body to allow for extra comfort and range of movement. available in adult sizes s-3xl and youth sizes xxs-2xl and husky in 13 colors.', 'sizes: ["for waist 24\\"-28\\"","for waist 29\\"-34\\"","for waist 35\\"-40\\"","for waist 41\\"-46\\"","for waist 47\\"-50\\"","for waist 51\\"-55