In [None]:
#!/usr/bin/env python3
"""
semantic_search_engine.py

A modular semantic search engine for a Reddit life-advice corpus.

This version is set up for interactive use in notebooks: when executed as a
script or notebook cell, it reads the dataset from the path set in the
``__main__`` block, then builds and saves the FAISS index.
"""
from __future__ import annotations

import argparse
import json
import os
import pickle
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer

try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional: without it, iterate silently with no progress bar.
    def tqdm(iterable=None, *args, **kwargs):
        """Stand-in for ``tqdm.tqdm`` that hands back *iterable* untouched.

        Any extra positional/keyword arguments (``desc=``, ``total=``, ...)
        are accepted and ignored. When no iterable is given, an empty range
        is returned.
        """
        if iterable is None:
            return range(0)
        return iterable


class SemanticSearchEngine:
    def __init__(self,
                 model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 index_dir: str | Path = "index",
                 use_gpu: bool = False):
        self.index_dir = Path(index_dir)
        self.index_dir.mkdir(parents=True, exist_ok=True)
        self.index_path = self.index_dir / "faiss.index"
        self.meta_path = self.index_dir / "metadata.pkl"

        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        if use_gpu:
            self.model = self.model.to("cuda")

        self.index: faiss.Index | None = None
        self.metadata: List[Dict] = []

    @staticmethod
    def _normalise(emb: np.ndarray) -> np.ndarray:
        norm = np.linalg.norm(emb, axis=1, keepdims=True)
        return emb / np.maximum(norm, 1e-12)

    def build(self, docs: List[str], metas: List[Dict], hnsw_m: int = 32):
        if len(docs) != len(metas):
            raise ValueError("docs and metas must have identical length")

        print(f"Encoding {len(docs)} documents with {self.model_name} …")
        embeddings = self.model.encode(docs, batch_size=128, show_progress_bar=True, convert_to_numpy=True)
        embeddings = self._normalise(embeddings.astype('float32'))

        dim = embeddings.shape[1]
        self.index = faiss.IndexHNSWFlat(dim, hnsw_m, faiss.METRIC_INNER_PRODUCT)
        self.index.hnsw.efConstruction = 200

        print("Adding vectors to index …")
        self.index.add(embeddings)
        self.metadata = metas
        self.save()
        print(f"✓ Built index with {self.index.ntotal} vectors → {self.index_path.resolve()}")


    def save(self):
        if self.index is None:
            raise RuntimeError("Index not initialised; call build() first")
        faiss.write_index(self.index, str(self.index_path))
        with open(self.meta_path, "wb") as f:
            pickle.dump(self.metadata, f)

    def load(self):
        if self.index is None:
            print("Loading index into memory …")
            self.index = faiss.read_index(str(self.index_path))
        if not self.metadata:
            with open(self.meta_path, "rb") as f:
                self.metadata = pickle.load(f)

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        self.load()
        q_emb = self.model.encode([query], convert_to_numpy=True)
        q_emb = self._normalise(q_emb.astype('float32'))

        scores, idxs = self.index.search(q_emb, top_k)
        results = []
        for score, idx in zip(scores[0], idxs[0]):
            item = self.metadata[idx].copy()
            item["score"] = float(score)
            results.append(item)
        return results

    def rag_answer(self, query: str, context_k: int = 5, model: str = "gpt-3.5-turbo-0125",
                   openai_api_key: str | None = None) -> str:
        import textwrap
        import openai

        openai.api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        if not openai.api_key:
            raise EnvironmentError("OPENAI_API_KEY missing; pass via env var or parameter")

        contexts = self.search(query, top_k=context_k)
        concatenated = "\n\n".join(
            [f"Title: {c['title']}\nURL: {c['url']}" for c in contexts]
        )

        sys_prompt = "You are a friendly assistant who provides concise, actionable life advice."
        user_prompt = (
            f"Based on the Reddit advice below, answer the QUESTION in 3–5 bullet points.\n\n"
            f"REDDIT ADVICE:\n{concatenated}\n\nQUESTION: {query}"
        )

        response = openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "system", "content": sys_prompt},
                      {"role": "user", "content": user_prompt}],
            temperature=0.7,
        )
        return textwrap.fill(response.choices[0].message.content.strip(), 100)


def parse_reddit_jsonl(path: str | Path) -> tuple[List[str], List[Dict]]:
    """Read a JSONL file of Reddit posts into parallel (docs, metas) lists for build().

    Each non-blank line must be a JSON object. A post's document text is its
    ``title``, ``body`` and string entries of ``top_comments`` joined with
    single spaces; fields that are missing or null are treated as empty.

    Returns:
        docs: One text per post, in file order.
        metas: Parallel list of dicts with keys:
            ``"title"`` (the post title, or the first 60 characters of the
            text with "…" appended when truncated), ``"url"``,
            ``"post_score"`` (the post's ``score``, default 0), and
            ``"text"`` (the full document text, usable as RAG context).
    """
    docs: List[str] = []
    metas: List[Dict] = []

    with open(path, "r", encoding="utf-8") as fh:
        for line in tqdm(fh, desc="Reading JSONL"):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF);
                # json.loads("") would raise.
                continue
            item = json.loads(line)
            # `or ""` / `or []` also covers keys that are present but null.
            title = (item.get("title") or "").strip()
            body = (item.get("body") or "").strip()
            comments = " ".join(
                c for c in (item.get("top_comments") or []) if isinstance(c, str)
            )
            # Skip empty parts so missing fields don't leave double spaces.
            full_text = " ".join(part for part in (title, body, comments) if part).strip()

            if title:
                display_title = title
            elif len(full_text) > 60:
                display_title = full_text[:60] + "…"
            else:
                display_title = full_text

            docs.append(full_text)
            metas.append({
                "title": display_title,
                "url": item.get("url"),
                "post_score": item.get("score", 0),
                "text": full_text,
            })
    return docs, metas


# Build the index when run as a script or executed as a notebook cell.
if __name__ == "__main__":
    dataset_path = "reddit_wisdom_data.jsonl"  # <<<<<< SET YOUR PATH HERE

    # Optional CLI overrides. parse_known_args() ignores the extra arguments a
    # Jupyter kernel receives (e.g. "-f <connection-file>"), and only optional
    # flags are defined so those extras are never mistaken for a dataset path.
    _parser = argparse.ArgumentParser(
        description="Build the semantic search index.", allow_abbrev=False,
    )
    _parser.add_argument("--data", default=dataset_path,
                         help="Path to the Reddit JSONL dataset.")
    _parser.add_argument("--index-dir", default="index",
                         help="Directory to write the FAISS index and metadata to.")
    _args, _unknown = _parser.parse_known_args()
    dataset_path = _args.data

    # Check before loading the (slow) embedding model.
    if not Path(dataset_path).is_file():
        raise FileNotFoundError(f"Dataset not found: {dataset_path}")

    engine = SemanticSearchEngine(index_dir=_args.index_dir)
    docs, metas = parse_reddit_jsonl(dataset_path)
    engine.build(docs, metas)
    print("Search engine index built and saved.")
