# ***SETUP***

In [None]:
import os
import faiss
import numpy as np
import gc
import time
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables from .env
load_dotenv()

# Required settings, read once at startup. The string passed to os.getenv is
# the exact environment-variable key that must be defined.
gemini_flash_api_key = os.getenv("FlashAPI")
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")

# Fail fast: a missing or empty value raises ValueError here rather than
# surfacing later as a connection or authentication error. Each message names
# the key that was actually looked up so the operator knows what to set.
if not gemini_flash_api_key:
    raise ValueError("❌ Gemini Flash API key (FlashAPI) is missing!")
if not mongo_uri:
    raise ValueError("❌ MongoDB URI (MONGO_URI) is missing!")
if not index_uri:
    raise ValueError("❌ INDEX_URI for FAISS index cluster is missing!")

# --- Environment variables to mitigate segmentation faults ---
# Single-threaded OpenMP/MKL and no tokenizer parallelism.
# NOTE(review): faiss and numpy are imported above this point; confirm their
# native runtimes still honour these variables when they are set this late.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# --- Setup local project directory (for model cache) ---
project_dir = "./AutoGenRAGMedicalChatbot"
huggingface_cache_dir = os.path.join(project_dir, "huggingface_models")
os.makedirs(project_dir, exist_ok=True)
# Point the Hugging Face cache at the folder inside the project directory.
os.environ["HF_HOME"] = huggingface_cache_dir

# --- Resolve the SentenceTransformer snapshot (cache hit or fresh download) ---
from huggingface_hub import snapshot_download

print("Checking or downloading the all-MiniLM-L6-v2 model from huggingface_hub...")
embedding_repo_id = "sentence-transformers/all-MiniLM-L6-v2"
hf_cache_root = os.environ["HF_HOME"]
model_loc = snapshot_download(
    repo_id=embedding_repo_id,
    cache_dir=hf_cache_root,
    local_files_only=False,  # network access is allowed when the cache is cold
)
print(f"Model directory: {model_loc}")

# Load the encoder from the resolved snapshot directory, pinned to CPU.
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer(model_loc, device="cpu")

# --- MongoDB Setup ---
import gridfs  # distributed with pymongo
from pymongo import MongoClient

# QA client: database and collection holding the question/answer records.
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]  # Use your chosen database name
qa_collection = db["qa_data"]

# FAISS index client: a second connection, configured through INDEX_URI.
iclient = MongoClient(index_uri)
idb = iclient["FAISSIndexCluster"]  # Use your chosen database name
index_collection = idb["faiss_index"]

# GridFS handle on the index database. The index-storage step later in this
# file calls fs.find_one / fs.delete / fs.put, so `fs` has to exist before it
# runs; without this line that step fails with NameError.
fs = gridfs.GridFS(idb)

# QA Embedding

In [None]:
# --- Load or Build QA Data in MongoDB ---
# Ends with `qa_data` bound to a list of dict records in either branch.
print("✅ Checking MongoDB for existing QA data...")
stored_count = qa_collection.count_documents({})
if stored_count > 0:
    print("✅ Loaded existing QA data from MongoDB.")
    # Sort by "i" inside an aggregation (allowDiskUse) so no index on "i" is
    # required, and drop "_id" from the returned documents.
    ordered_pipeline = [
        {"$sort": {"i": 1}},
        {"$project": {"_id": 0}},
    ]
    qa_data = list(qa_collection.aggregate(ordered_pipeline, allowDiskUse=True))
    print("Total QA entries loaded:", len(qa_data))
else:
    print("⚠️ QA data not found in MongoDB. Loading dataset from Hugging Face...")
    from datasets import load_dataset

    dataset = load_dataset("ruslanmv/ai-medical-chatbot", cache_dir=huggingface_cache_dir)
    df = dataset["train"].to_pandas()[["Patient", "Doctor"]]
    # Positional column "i" records the original row order for later sorting.
    df["i"] = range(len(df))
    qa_data = df.to_dict("records")
    # Insert in slices of 1000 records instead of one oversized request.
    batch_size = 1000
    for start in range(0, len(qa_data), batch_size):
        chunk = qa_data[start:start + batch_size]
        qa_collection.insert_many(chunk)
    print(f"✅ QA data stored in MongoDB. Total entries: {len(qa_data)}")


# FAISS Index Embedding

In [None]:
# --- Build FAISS Index ---
print("Building a compressed FAISS index (using IVFPQ) from QA data...")


def _field_text(doc, key):
    """Return ``doc[key]`` when it is a string, otherwise an empty string.

    A record can carry the key with a null value (``None`` or a float NaN);
    ``dict.get`` then returns that value instead of its default, and string
    concatenation would raise ``TypeError``.
    """
    value = doc.get(key, "")
    return value if isinstance(value, str) else ""


# One text per QA pair: the "Patient" and "Doctor" fields joined by a space.
texts = [f"{_field_text(doc, 'Patient')} {_field_text(doc, 'Doctor')}" for doc in qa_data]
print("Total texts to embed:", len(texts))
if not texts:
    # np.vstack([]) would fail with a message unrelated to the real cause.
    raise ValueError("No QA entries available to embed; cannot build the FAISS index.")

# Encode in slices so progress is reported while the embeddings are computed.
batch_size = 512
embeddings_list = []
for i in range(0, len(texts), batch_size):
    batch = texts[i: i + batch_size]
    batch_embeddings = embedding_model.encode(batch, convert_to_numpy=True).astype(np.float32)
    embeddings_list.append(batch_embeddings)
    print(f"Encoded batch {i} to {i + len(batch)}")

# Stack the per-batch arrays into one float32 matrix; its second axis is the
# vector dimension used to configure the index.
embeddings = np.vstack(embeddings_list)
dim = embeddings.shape[1]
print("Embeddings shape:", embeddings.shape)

# --- Build Compressed FAISS Index using IVFPQ (to reduce storage size) ---
# nlist: number of clusters, m: number of subquantizers, nbits: bits per subvector
nlist, m, nbits = 100, 8, 8

# Flat L2 index handed to IndexIVFPQ as its quantizer.
quantizer = faiss.IndexFlatL2(dim)
index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, nbits)

# Train on the full embedding matrix, then add the same vectors to the index.
print("Training the IVFPQ index on embeddings...")
index.train(embeddings)
index.add(embeddings)
print("Compressed FAISS index built. Total vectors:", index.ntotal)

# --- Serialize and Store FAISS Index in GridFS (on separate cluster) ---
# NOTE(review): `fs` is expected to be a gridfs.GridFS handle on the index
# database; confirm it is created before this cell runs.
print("Serializing FAISS index...")
# faiss.serialize_index returns a NumPy uint8 array; tobytes() yields the
# native bytes object that GridFS stores, with no extra frombuffer pass.
index_data = faiss.serialize_index(index).tobytes()

# Delete every existing FAISS index file in GridFS, not just the first match,
# so stale versions with the same filename do not accumulate. The ids are
# collected first so nothing is deleted while the cursor is being iterated.
index_filename = "faiss_index.bin"
stale_ids = [stored._id for stored in fs.find({"filename": index_filename})]
for stale_id in stale_ids:
    fs.delete(stale_id)

file_id = fs.put(index_data, filename=index_filename)
print("✅ FAISS index stored in GridFS with file_id:", file_id)

# Drop the embedding matrix now that the index has been stored.
del embeddings
gc.collect()
print("✅ Compressed FAISS index stored in MongoDB (separate cluster) successfully!")
