In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Train text-embedding based recommendation + clustering on the cleaned dataset.

Inputs:
  --data backend/data/cleaned_furniture.csv

Artifacts (default --out backend/storage):
  - embeddings.npy               (float32, shape [N, D])
  - faiss_index.bin             (FAISS cosine index)
  - metadata.json               (list of product dicts aligned with embeddings)
  - cluster_labels.csv          (uniq_id, cluster_label)
  - training_report.json        (summary: rows, dims, timing, silhouette, etc.)

Usage:
  python train_models.py --data backend/data/cleaned_furniture.csv --out backend/storage --clusters 12
"""

import argparse
import json
import os
import sys
import time
from typing import List, Dict

import numpy as np
import pandas as pd

# Embeddings
from sentence_transformers import SentenceTransformer

# Clustering
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Vector index
try:
    import faiss  # faiss-cpu
except Exception as e:
    print("[train] ERROR: faiss not installed. Install faiss-cpu", flush=True)
    raise

# ---------------------------
# Helpers
# ---------------------------

def log(msg: str) -> None:
    """Write ``msg`` to stdout behind the ``[train]`` tag, flushing right away."""
    line = f"[train] {msg}"
    # Flush so progress lines show up immediately even when stdout is piped.
    print(line, flush=True)

def read_cleaned_csv(path: str) -> pd.DataFrame:
    """
    Load the cleaned product CSV and make sure the columns used downstream exist.

    Args:
        path: Location of the cleaned CSV file.

    Returns:
        The loaded DataFrame where
          - every expected normalized column is present (a missing one is
            created as all-NaN and a warning is logged),
          - the text columns have NaN replaced by a placeholder and are cast to ``str``,
          - ``price_num`` is coerced to numeric (unparseable values become NaN),
          - a new ``categories_list`` column holds a list of category strings per row.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    import ast  # local import: only needed to parse list-like category strings

    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing dataset: {path}")
    df = pd.read_csv(path)

    # Ensure expected normalized columns exist (from the cleaner script)
    expected = [
        "uniq_id", "title_norm", "brand_norm", "description_norm", "price_num",
        "categories_norm", "primary_image", "material_std", "color_std"
    ]
    for c in expected:
        if c not in df.columns:
            log(f"WARNING: column '{c}' not found; creating placeholder.")
            df[c] = np.nan

    # Minimal cleanups: placeholder text for missing values, always plain str.
    text_defaults = {
        "title_norm": "Unknown Title",
        "brand_norm": "Unknown",
        "description_norm": "No description available.",
        "material_std": "unknown",
        "color_std": "unknown",
    }
    for col, placeholder in text_defaults.items():
        df[col] = df[col].fillna(placeholder).astype(str)

    # Prices are converted with float() downstream; coerce here so a stray
    # non-numeric cell becomes NaN instead of raising during metadata export.
    df["price_num"] = pd.to_numeric(df["price_num"], errors="coerce")

    def parse_cats(x) -> List[str]:
        """Turn one raw ``categories_norm`` cell into a list of clean category names."""
        if pd.isna(x):
            return []
        s = str(x).strip()
        # categories_norm might be a list serialized as a string, e.g. "['a', 'b']"
        if s.startswith("[") and s.endswith("]"):
            try:
                v = ast.literal_eval(s)
            except Exception:
                # Best effort: not a valid Python literal -> use the comma split below.
                v = None
            if isinstance(v, list):
                # Strip and drop empties so both parsing paths yield the same shape of tokens.
                tokens = [str(t).strip() for t in v]
                return [t for t in tokens if t]
        # fallback: split on commas
        return [t.strip() for t in s.split(",") if t.strip()]

    df["categories_list"] = df["categories_norm"].apply(parse_cats)
    return df

def build_corpus_row(r: pd.Series) -> str:
    """
    Build the single text string that gets embedded for one product row.

    The segments are, in order: title, brand, material, color, up to the first
    three categories (joined with " > "), and the description. Empty segments
    are dropped and the rest are joined with " | ".
    """
    top_categories = r.get("categories_list", [])[:3]
    category_path = " > ".join(top_categories)

    title = r.get("title_norm", "")
    brand = f"Brand: {r.get('brand_norm', '')}"
    material = f"Material: {r.get('material_std', '')}"
    color = f"Color: {r.get('color_std', '')}"
    description = r.get("description_norm", "")

    segments = [title, brand, material, color]
    if category_path:
        # Only mention categories when there is at least one.
        segments.append(f"Categories: {category_path}")
    segments.append(description)

    return " | ".join(seg for seg in segments if seg)

def embed_texts(model: SentenceTransformer, texts: List[str], batch_size: int = 128) -> np.ndarray:
    """
    Encode ``texts`` with ``model`` in batches and return one float32 matrix.

    Args:
        model: Embedder exposing ``encode(...)`` and ``get_sentence_embedding_dimension()``.
        texts: Strings to embed; row ``i`` of the result corresponds to ``texts[i]``.
        batch_size: Number of texts passed to each ``model.encode`` call (must be >= 1).

    Returns:
        A float32 array with one row per text. The embeddings are NOT normalized.
        For an empty ``texts`` the result has shape ``(0, D)`` where ``D`` is the
        dimension reported by the model (0 if the model reports none).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if len(texts) == 0:
        # np.vstack([]) would raise; return an empty matrix of the right width instead.
        dim = model.get_sentence_embedding_dimension() or 0
        return np.empty((0, dim), dtype=np.float32)

    batches = [
        model.encode(
            texts[start:start + batch_size],
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=False,
        )
        for start in range(0, len(texts), batch_size)
    ]
    # copy=False: skip the extra copy when the stacked result is already float32.
    return np.vstack(batches).astype(np.float32, copy=False)

def l2_normalize(vectors: np.ndarray) -> np.ndarray:
    """
    Divide every row of ``vectors`` by its Euclidean norm.

    A tiny epsilon is added to each norm so an all-zero row divides cleanly
    (and stays all-zero) instead of producing NaN. The input is not modified.
    """
    eps = 1e-12
    row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / (row_norms + eps)

def build_faiss_cosine_index(vectors_unit: np.ndarray):
    """
    Create a flat inner-product FAISS index over ``vectors_unit`` and return it.

    The inner product of two unit-length vectors equals their cosine
    similarity, so the caller is expected to pass L2-normalized rows.
    The index width is taken from the second axis of ``vectors_unit``.
    """
    dim = vectors_unit.shape[1]
    ip_index = faiss.IndexFlatIP(dim)
    ip_index.add(vectors_unit)
    return ip_index

def safe_json_dump(obj, path: str) -> None:
    """
    Write ``obj`` to ``path`` as UTF-8 JSON (2-space indent, non-ASCII kept as-is).

    The object is serialized to a string BEFORE the file is opened, so an
    object that cannot be serialized raises without truncating or partially
    overwriting an existing file at ``path``.

    Raises:
        TypeError / ValueError: propagated from ``json.dumps`` for unserializable input.
        OSError: propagated from opening or writing the file.
    """
    payload = json.dumps(obj, ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as f:
        f.write(payload)

# ---------------------------
# Main training
# ---------------------------

def main():
    """
    Command-line entry point: embed the cleaned dataset, cluster it, index it, save artifacts.

    Steps, in order:
      1. Parse CLI arguments and create the output directory.
      2. Load the cleaned CSV and compose one text per row.
      3. Embed the texts with the chosen sentence-transformer model.
      4. Run KMeans on the raw embeddings (skipped when fewer than 2 clusters are possible).
      5. Build a FAISS inner-product index on the L2-normalized embeddings.
      6. Write embeddings, index, metadata, cluster labels and a training report.
      7. Print the top hits for the sample queries and, when stdin is a TTY,
         run a small interactive search prompt.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", required=True, help="Path to cleaned CSV (e.g., backend/data/cleaned_furniture.csv)")
    ap.add_argument("--out", default="backend/storage", help="Output directory for artifacts")
    ap.add_argument("--clusters", type=int, default=12, help="KMeans number of clusters")
    ap.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2", help="HF sentence-transformer")
    ap.add_argument("--eval_queries", nargs="*", default=[
        "minimalist wooden chair under 5000",
        "compact study table for small room",
        "king size bed modern design",
        "ergonomic office chair with lumbar support",
        "glass dining table for 4"
    ], help="Sample natural language queries for quick eval")
    args = ap.parse_args()

    os.makedirs(args.out, exist_ok=True)

    t0 = time.time()
    log(f"Loading cleaned dataset: {args.data}")
    df = read_cleaned_csv(args.data)

    # Compose corpus: one embedding text per row, in row order.
    log("Composing text corpus for embeddings…")
    corpus = [build_corpus_row(r) for _, r in df.iterrows()]

    # Load embedder
    log(f"Loading embedder: {args.model}")
    emb_model = SentenceTransformer(args.model)

    # Compute embeddings
    log("Embedding texts… (this may take a bit)")
    X = embed_texts(emb_model, corpus, batch_size=128)   # [N, D]
    d = X.shape[1]
    log(f"Embeddings shape: {X.shape}")

    # Normalize for cosine similarity (used by the FAISS index; KMeans uses raw X).
    X_unit = l2_normalize(X)

    # Train KMeans. k can never exceed the number of rows.
    k = min(args.clusters, len(df)) if len(df) > 1 else 1
    if k < 2:
        log("Not enough rows for clustering; skipping KMeans.")
        labels = np.zeros((len(df),), dtype=int)
        sil = None
    else:
        log(f"Training KMeans with k={k}…")
        km = KMeans(n_clusters=k, random_state=42, n_init="auto")
        labels = km.fit_predict(X)
        # The silhouette score is only defined when the number of distinct labels
        # is between 2 and N-1; treat a failure as "no score" but say why.
        try:
            sil = float(silhouette_score(X, labels))
        except Exception as exc:
            log(f"WARNING: silhouette score unavailable ({exc})")
            sil = None
        log(f"KMeans trained. Silhouette: {sil}")

    # Build FAISS cosine index
    log("Building FAISS cosine index…")
    faiss_index = build_faiss_cosine_index(X_unit)

    # Artifact locations
    emb_path = os.path.join(args.out, "embeddings.npy")
    faiss_path = os.path.join(args.out, "faiss_index.bin")
    meta_path = os.path.join(args.out, "metadata.json")
    clusters_path = os.path.join(args.out, "cluster_labels.csv")
    report_path = os.path.join(args.out, "training_report.json")

    log(f"Saving embeddings -> {emb_path}")
    # copy=False: no extra copy when X is already float32.
    np.save(emb_path, X.astype("float32", copy=False))

    log(f"Saving FAISS index -> {faiss_path}")
    faiss.write_index(faiss_index, faiss_path)

    # Build and save metadata aligned with rows: meta[i] describes embedding row i.
    log("Saving metadata…")
    meta: List[Dict] = []
    for i, r in df.reset_index(drop=True).iterrows():
        price = r.get("price_num")
        image = r.get("primary_image")
        meta.append({
            "row": int(i),
            "uniq_id": str(r.get("uniq_id", "")),
            "title": str(r.get("title_norm", "")),
            "brand": str(r.get("brand_norm", "")),
            "price": None if pd.isna(price) else float(price),
            "categories": r.get("categories_list", []),
            "image_url": None if pd.isna(image) else str(image),
            "material": str(r.get("material_std", "")),
            "color": str(r.get("color_std", "")),
        })
    safe_json_dump(meta, meta_path)

    # Save cluster labels
    log(f"Saving cluster labels -> {clusters_path}")
    pd.DataFrame({"uniq_id": df["uniq_id"], "cluster_label": labels}).to_csv(clusters_path, index=False)

    # Report
    t1 = time.time()
    report = {
        "rows": int(len(df)),
        "embed_dim": int(d),
        "clusters": int(k),
        "silhouette": sil,
        "time_sec": round(t1 - t0, 2),
        "artifacts": {
            "embeddings": emb_path,
            "faiss_index": faiss_path,
            "metadata": meta_path,
            "cluster_labels": clusters_path
        },
        "model": args.model,
    }
    safe_json_dump(report, report_path)
    log(f"Training complete in {report['time_sec']}s")

    # ---------------------------
    # Quick eval with sample queries (cosine)
    # ---------------------------
    def search_top_k(query: str, top_k: int = 5):
        """Embed ``query`` and return up to ``top_k`` result dicts from the index, best first."""
        q = emb_model.encode([query], convert_to_numpy=True, normalize_embeddings=False).astype("float32")
        q = l2_normalize(q)
        scores, idxs = faiss_index.search(q, top_k)  # inner product on unit vecs = cosine
        results = []
        for s, j in zip(scores[0].tolist(), idxs[0].tolist()):
            if j < 0:  # FAISS returns -1 if fewer than top_k
                continue
            item = meta[j]
            results.append({
                "score": float(s),
                "uniq_id": item["uniq_id"],
                "title": item["title"],
                "brand": item["brand"],
                "price": item["price"],
                "categories": item["categories"],
            })
        return results

    log("----- Quick Eval (sample queries) -----")
    for q in args.eval_queries:
        hits = search_top_k(q, top_k=5)
        log(f"Q: {q}")
        for h in hits:
            log(f"  • {h['title']}  | {h['brand']}  | ₹{h['price']}  | score={h['score']:.3f}")
    log("---------------------------------------")

    # Provide a tiny interactive CLI if run in a TTY
    if sys.stdin.isatty():
        log("Enter a query to search (or blank to exit):")
        try:
            while True:
                user_q = input("> ").strip()
                if not user_q:
                    break
                for h in search_top_k(user_q, top_k=5):
                    print(f"  - {h['title']} | {h['brand']} | ₹{h['price']} | score={h['score']:.3f}")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C simply ends the interactive session.
            pass


if __name__ == "__main__":
    main()
