In [None]:
"""
1) Streams interactions from MongoDB in batches (two passes: ID discovery, then ratings)
2) Builds a sparse book×user CSR matrix
3) Runs TruncatedSVD to get low-dim embeddings
4) Builds a Faiss index for cosine similarity
5) For each book, finds top-K similar books and writes them
   into MongoDB under the `similar_books` field.
"""

import os
import pickle
import numpy as np
import faiss
from pymongo import MongoClient, UpdateOne
from scipy.sparse import coo_matrix
from sklearn.decomposition import TruncatedSVD

# ─────────────────────────────── CONFIG ────────────────────────────────────────
MONGO_URI      = os.getenv("MONGO_URI", "mongodb://127.0.0.1:27017")  # overridable via the MONGO_URI env var
DB_NAME        = "book_database"
INTER_COLL     = "goodreads_interactions"  # read: user_id, book_id and optional rating per doc
BOOKS_COLL     = "books"                   # written: docs matched on book_id get `similar_books`
BATCH_SIZE     = 2_000_000           # cursor batch_size for both interaction scans; also the progress-print interval
                                     # NOTE(review): MongoDB caps each returned batch at 16 MiB, so the effective batch may be smaller
EMBED_DIM      = 50                  # SVD embedding size (TruncatedSVD n_components)
FAISS_INDEX_TYPE = "Flat"            # "Flat" (exact) or "HNSW" (approximate); anything else raises ValueError
TOP_K          = 10                  # how many neighbors to store per book
BULK_BATCH     = 5_000               # UpdateOne ops sent per Mongo bulk_write call

# ──────────────────── STEP 1: STREAM & BUILD SPARSE MATRIX ────────────────────
def stream_and_build_matrix():
    """Read every interaction from MongoDB into a book × user sparse matrix.

    Two passes are made over ``INTER_COLL``:

    1. collect every distinct ``user_id`` and ``book_id`` and give each a
       dense index in sorted-id order;
    2. record one (book index, user index, rating) triplet per document.

    A missing or ``null`` ``rating`` counts as 1.0. Documents whose ids were
    not seen in pass 1 (e.g. inserted between the two scans) are skipped and
    reported instead of raising ``KeyError``. Duplicate (book, user) pairs
    are summed by the COO -> CSR conversion.

    Returns:
        tuple: ``(mat, idx2book)`` where ``mat`` is a CSR matrix of shape
        ``(n_books, n_users)`` and ``idx2book`` maps each row index of
        ``mat`` to its ``book_id``.
    """
    # Typed C buffers: far smaller than lists of boxed Python ints/floats,
    # which matters when the collection holds hundreds of millions of rows.
    from array import array

    client = MongoClient(MONGO_URI)
    try:
        coll = client[DB_NAME][INTER_COLL]

        # 1a) discover all unique users & books
        users, books = set(), set()
        for doc in coll.find({}, {"user_id": 1, "book_id": 1}, batch_size=BATCH_SIZE):
            users.add(doc["user_id"])
            books.add(doc["book_id"])
        user2idx = {u: i for i, u in enumerate(sorted(users))}
        book2idx = {b: i for i, b in enumerate(sorted(books))}
        del users, books  # no longer needed; release them before the second scan
        idx2book = {i: b for b, i in book2idx.items()}
        n_users, n_books = len(user2idx), len(book2idx)
        print(f"→ Found {n_books:,} books & {n_users:,} users.")

        # 1b) stream again into COO triplets
        rows, cols, data = array("i"), array("i"), array("d")
        count = skipped = 0
        projection = {"user_id": 1, "book_id": 1, "rating": 1}
        for doc in coll.find({}, projection, batch_size=BATCH_SIZE):
            r = book2idx.get(doc["book_id"])
            c = user2idx.get(doc["user_id"])
            if r is None or c is None:
                # id appeared after pass 1 finished; it has no matrix slot
                skipped += 1
                continue
            rating = doc.get("rating")
            rows.append(r)
            cols.append(c)
            data.append(1.0 if rating is None else float(rating))
            count += 1
            if count % BATCH_SIZE == 0:
                print(f"  streamed {count:,} interactions…")
    finally:
        client.close()
    print(f"→ Total interactions: {count:,}")
    if skipped:
        print(f"  skipped {skipped:,} interactions with ids unseen in the first pass")

    # np.asarray consumes the array buffers via the buffer protocol
    # (C int -> np.intc, C double -> float64) without a per-element loop.
    mat = coo_matrix(
        (
            np.asarray(data, dtype=np.float64),
            (np.asarray(rows, dtype=np.intc), np.asarray(cols, dtype=np.intc)),
        ),
        shape=(n_books, n_users),
    ).tocsr()
    return mat, idx2book

# ─────────────────── STEP 2: TRUNCATED SVD FOR EMBEDDINGS ─────────────────────
def compute_embeddings(mat, n_components=EMBED_DIM):
    """Reduce each row of *mat* to an ``n_components``-dimensional embedding.

    Uses scikit-learn's ``TruncatedSVD`` (ARPACK solver, fixed random_state
    for reproducibility) and prints the total explained-variance ratio.

    Args:
        mat: sparse matrix with one row per book (n_books, n_users).
        n_components: embedding width.

    Returns:
        numpy.ndarray: C-contiguous float32 array of shape
        ``(n_books, n_components)``.
    """
    print(f"→ Running TruncatedSVD to {n_components} dims…")
    svd = TruncatedSVD(n_components=n_components, algorithm="arpack", random_state=42)
    emb = svd.fit_transform(mat)  # shape (n_books, n_components), float64
    print(f"   explained variance: {svd.explained_variance_ratio_.sum():.4f}")
    # Faiss works on float32, and its in-place helpers (faiss.normalize_L2)
    # need a C-contiguous float32 buffer; the SVD output is float64 and may
    # not be C-ordered.
    return np.ascontiguousarray(emb, dtype=np.float32)

# ────────────────── STEP 3: BUILD FAISS COSINE INDEX ───────────────────────────
def build_faiss_index(emb, index_type=FAISS_INDEX_TYPE):
    """Build a Faiss index whose nearest neighbours are the most cosine-similar rows.

    *emb* is L2-normalised **in place** (zero-norm rows stay zero) so that the
    caller can reuse the same array as queries against the returned index.

    Args:
        emb: 2-D float array, one embedding per row.
        index_type: ``"Flat"`` for exact inner-product search or ``"HNSW"``
            for approximate search.

    Returns:
        The populated Faiss index.

    Raises:
        ValueError: if *index_type* is not ``"Flat"`` or ``"HNSW"``.
    """
    d = emb.shape[1]
    # NumPy normalisation instead of faiss.normalize_L2: the latter only
    # accepts C-contiguous float32, whereas this works in place for any float
    # dtype/layout (e.g. the float64 output of TruncatedSVD).
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    np.divide(emb, norms, out=emb, where=norms > 0)
    if index_type == "Flat":
        # exact search; inner product of unit vectors == cosine similarity
        idx = faiss.IndexFlatIP(d)
    elif index_type == "HNSW":
        # approximate search with the default L2 metric; for unit vectors
        # ||a - b||^2 = 2 - 2*cos(a, b), so the neighbour ranking matches cosine
        idx = faiss.IndexHNSWFlat(d, 32)
    else:
        raise ValueError(f"Unknown index type: {index_type!r}")
    # explicit float32 for Faiss versions whose add() does not convert
    idx.add(np.ascontiguousarray(emb, dtype=np.float32))
    print(f"→ Faiss index built (ntotal={idx.ntotal})")
    return idx

# ──────────────── STEP 4: PRECOMPUTE & BULK-WRITE SIMILARS ─────────────────────
def _select_neighbours(row, candidates, idx2book, top_k):
    """Turn one row of Faiss result labels into at most *top_k* book ids.

    Skips the query's own row and Faiss' ``-1`` padding (returned when fewer
    results than requested were found).
    """
    picked = []
    for cand in map(int, candidates):
        if len(picked) >= top_k:
            break
        if cand >= 0 and cand != row:
            picked.append(idx2book[cand])
    return picked


def precompute_and_write(index, emb, idx2book, top_k=TOP_K):
    """Store each book's nearest neighbours in the ``similar_books`` field.

    Row ``i`` of *emb* must be the vector added to *index* at position ``i``,
    and ``idx2book[i]`` its ``book_id``. Neighbours are searched in batches of
    ``BULK_BATCH`` rows, and each batch is written with one ``bulk_write`` of
    ``UpdateOne`` ops matched on ``book_id`` (no upsert, so books without a
    document in ``BOOKS_COLL`` are left untouched).
    NOTE(review): ``book_id`` must have the same type in both collections for
    the filter to match — confirm against the data.

    Args:
        index: Faiss index built over *emb*.
        emb: 2-D embedding array whose rows line up with the index.
        idx2book: row index -> book_id.
        top_k: maximum number of neighbours stored per book.
    """
    client = MongoClient(MONGO_URI)
    try:
        coll = client[DB_NAME][BOOKS_COLL]
        rows = list(idx2book)  # row positions in emb / index
        print("→ Computing & queuing updates…")
        for start in range(0, len(rows), BULK_BATCH):
            chunk = rows[start:start + BULK_BATCH]
            # One batched search per bulk write instead of one call per book;
            # k = top_k + 1 because each book normally finds itself first.
            queries = np.ascontiguousarray(emb[chunk], dtype=np.float32)
            _, labels = index.search(queries, top_k + 1)
            ops = [
                UpdateOne(
                    {"book_id": idx2book[row]},
                    {"$set": {"similar_books": _select_neighbours(row, cands, idx2book, top_k)}},
                )
                for row, cands in zip(chunk, labels)
            ]
            coll.bulk_write(ops)
            if len(ops) == BULK_BATCH:
                print(f"   wrote {len(ops)} docs…")
            else:
                print(f"   wrote final {len(ops)} docs.")
    finally:
        client.close()
    print("→ All books updated with `similar_books`.")

# ───────────────────────────── MAIN ────────────────────────────────────────────
if __name__ == "__main__":
    # Pipeline: Mongo interactions -> sparse matrix -> SVD embeddings ->
    # Faiss index -> `similar_books` written back to Mongo. Results are kept
    # as module-level names (inspectable afterwards when run as a notebook cell).

    # 1) build matrix (book × user CSR) and the row-index -> book_id map
    mat, idx2book = stream_and_build_matrix()

    # 2) compute embeddings (one row per book)
    embeddings = compute_embeddings(mat)

    # 3) build Faiss index; this L2-normalises `embeddings` in place, so
    #    step 4 queries with the same normalised vectors that were indexed
    faiss_idx = build_faiss_index(embeddings)

    # 4) precompute & write
    precompute_and_write(faiss_idx, embeddings, idx2book)
