In [2]:
%pip install annoy

StatementMeta(, 18184d11-bac4-4d8a-a7b4-898a7859c4e4, 8, Finished, Available, Finished)

Collecting annoy
  Downloading annoy-1.17.3.tar.gz (647 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m647.5/647.5 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ | done
[?25hBuilding wheels for collected packages: annoy
  Building wheel for annoy (setup.py) ... [?25l- \ | / - \ | done
[?25h  Created wheel for annoy: filename=annoy-1.17.3-cp311-cp311-linux_x86_64.whl size=77476 sha256=244679487111355cc08c5559cc9c6457f1397ff6905fdf84dd981c6fc41ca34b
  Stored in directory: /home/trusted-service-user/.cache/pip/wheels/33/e5/58/0a3e34b92bedf09b4c57e37a63ff395ade6f6c1099ba59877c
Successfully built annoy
Installing collected packages: annoy
Successfully installed annoy-1.17.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m2

In [3]:
# ============================================
# CELL 1 — IMPORTS & LOAD MODEL
# ============================================
import joblib
import numpy as np
from pyspark.sql.functions import col
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from annoy import AnnoyIndex

# Names of the lakehouse tables referenced by this notebook.
GOLD_LABEL_TABLE = "gold_ml_training_set"
ITEM_EMB_TABLE = "gold_two_tower_item_emb"
USER_EMB_TABLE = "gold_two_tower_user_emb"
CATALOG_TABLE = "gold_ml_catalog"

# Track catalog as a Spark DataFrame; it is joined against candidates later.
catalog_enriched = spark.table(CATALOG_TABLE)

# Re-ranker model deserialised from the lakehouse Files area.
# NOTE(review): joblib.load unpickles the file, so it must come from a trusted location.
lgbm_model = joblib.load("/lakehouse/default/Files/models/lgbm_re_ranker.pkl")

print("‚úÖ Model & catalog loaded.")


StatementMeta(, 18184d11-bac4-4d8a-a7b4-898a7859c4e4, 10, Finished, Available, Finished)

‚úÖ Model & catalog loaded.


In [4]:
# ============================================
# CELL 2 — UNIFIED DISCOVERY (LightGBM scoring)
# Using PREBUILT Annoy Indexes
# ============================================

from annoy import AnnoyIndex
from pyspark.sql.window import Window
from pyspark.sql.functions import col
import pyspark.sql.functions as F
import numpy as np
import joblib

# Table names
GOLD_LABEL_TABLE = "gold_ml_training_set"
USER_EMB_TABLE = "gold_two_tower_user_emb"
ITEM_EMB_TABLE = "gold_two_tower_item_emb"

# When AVOID_ARABIC is True, tracks whose name matches arabic_pattern are filtered out.
arabic_pattern = "[\\u0600-\\u06FF]"
AVOID_ARABIC = True

# -----------------------------------------------------
# Prebuilt Annoy indexes (Friends + Ocean)
# -----------------------------------------------------

# Vector length the indexes are opened with.
dim = 256

# Friends: the ANN index plus the lookup that turns an Annoy item id into a candidate id.
friends_index = AnnoyIndex(dim, "angular")
friends_index.load("/lakehouse/default/Files/annoy/friends_index.ann")
friends_map = joblib.load("/lakehouse/default/Files/annoy/friends_map.pkl")

# Ocean: same layout as the friends pair.
ocean_index = AnnoyIndex(dim, "angular")
ocean_index.load("/lakehouse/default/Files/annoy/ocean_index.ann")
ocean_map = joblib.load("/lakehouse/default/Files/annoy/ocean_map.pkl")

print("‚úÖ Prebuilt Annoy indexes loaded.")


# unified_discovery function
def unified_discovery(
    target_user_id: str,
    ocean_weight: float = 0.2,
    base_k_min: int = 200,
    base_k_max: int = 1000,
    top_n: int = 30
):
    """Retrieve, score and rank unseen track recommendations for one user.

    Candidates are pulled from the two prebuilt Annoy indexes
    (``friends_index`` / ``ocean_index``), tracks the user already has a
    positive label for are removed, the rest are joined with
    ``catalog_enriched``, scored with ``lgbm_model.predict_proba`` and
    de-duplicated per (track_name, artist_name).

    NOTE(review): this function reads a module-level ``AUDIO_FEATURES`` list
    that is not defined in the notebook cells visible here. Confirm it is
    defined (with the same columns and order the model was trained on) before
    calling, otherwise the call fails with ``NameError``.

    Args:
        target_user_id: value matched against ``spotify_user_id``.
        ocean_weight: share of the retrieval budget given to the ocean index;
            clipped to [0, 1].
        base_k_min: lower bound for the total number of neighbours retrieved.
        base_k_max: upper bound for the total number of neighbours retrieved.
        top_n: maximum number of recommendations returned and displayed.

    Returns:
        A Spark DataFrame with columns track_name, artist_name, score, tempo,
        energy, dance, mood, loud, acoust (at most ``top_n`` rows), or ``None``
        when the user has no embedding, no unseen candidate is found, or
        nothing is left to score after filtering.
    """
    print(f"\nüéØ Unified Discovery for user: {target_user_id}")

    # Clip before logging so the printed weight is the one actually applied.
    ocean_weight = float(np.clip(ocean_weight, 0.0, 1.0))
    print(f"üåä Using ocean_weight = {ocean_weight}")

    # --------------------------------------------------------
    # 1) Load user embedding
    # --------------------------------------------------------
    user_row = (
        spark.table(USER_EMB_TABLE)
        .filter(col("spotify_user_id") == target_user_id)
        .limit(1)
        .collect()
    )
    if not user_row:
        print("‚ö† User not found in user embedding table.")
        return None

    user_vec = np.array(user_row[0]["vector"], dtype="float32")

    # --------------------------------------------------------
    # 2) Block known tracks (positives)
    # --------------------------------------------------------
    history_rows = (
        spark.table(GOLD_LABEL_TABLE)
        .filter(col("spotify_user_id") == target_user_id)
        .filter(col("label") == 1)
        .select("spotify_id")
        .distinct()
        .collect()
    )
    history_set = {r.spotify_id for r in history_rows}
    print(f"üö´ Blocking {len(history_set)} known tracks.")

    # --------------------------------------------------------
    # 3) Retrieve using prebuilt ANN indexes
    # --------------------------------------------------------
    # Retrieval budget grows with the user's history (10 per known track),
    # bounded by [base_k_min, base_k_max].
    activity = max(1, len(history_set))
    k = int(np.clip(activity * 10, base_k_min, base_k_max))
    print(f"üìå Retrieval size k = {k}")

    k_ocean   = int(k * ocean_weight)
    k_friends = k - k_ocean

    print(f"üîÑ Retrieval split: {k_ocean} OCEAN + {k_friends} FRIENDS")

    # ---- FRIENDS retrieval ----
    friend_ids = (
        friends_index.get_nns_by_vector(user_vec, k_friends)
        if k_friends > 0 else []
    )
    friend_candidates = [friends_map[i] for i in friend_ids]

    # ---- OCEAN retrieval ----
    ocean_ids = (
        ocean_index.get_nns_by_vector(user_vec, k_ocean)
        if k_ocean > 0 else []
    )
    ocean_candidates = [ocean_map[i] for i in ocean_ids]

    # Merge + remove duplicates + remove known tracks.
    # dict.fromkeys de-duplicates while keeping retrieval order, so the
    # candidate list is the same from run to run (a set would not be).
    candidates = list(dict.fromkeys(
        cid for cid in (friend_candidates + ocean_candidates)
        if cid not in history_set
    ))

    print(f"üåä OCEAN used:   {len(ocean_candidates)}")
    print(f"ü§ù FRIENDS used: {len(friend_candidates)}")
    print(f"‚úÖ Total unseen candidates: {len(candidates)}")

    if not candidates:
        print("‚ö† No unseen candidates found.")
        return None

    # --------------------------------------------------------
    # 4) Join candidates with catalog (audio features)
    # --------------------------------------------------------
    candidates_df = spark.createDataFrame(
        [(target_user_id, sid) for sid in candidates],
        ["spotify_user_id", "spotify_id"]
    )

    recs = candidates_df.join(catalog_enriched, "spotify_id", "left")

    if AVOID_ARABIC:
        # rlike on a NULL track_name yields NULL, so candidates without a
        # catalog match are dropped by this filter as well.
        recs = recs.filter(~col("track_name").rlike(arabic_pattern))

    # Fill missing audio features
    recs = recs.fillna({c: 0.0 for c in AUDIO_FEATURES})

    # --------------------------------------------------------
    # 5) SCORE with LightGBM
    # --------------------------------------------------------
    pdf = recs.select(
        "spotify_id",
        "track_name",
        "artist_name",
        *AUDIO_FEATURES
    ).toPandas()

    if pdf.empty:
        print("‚ö† No rows to score after filtering.")
        return None

    X = pdf[AUDIO_FEATURES].values.astype("float32")
    # Column 1 of predict_proba is used as the ranking score.
    scores = lgbm_model.predict_proba(X)[:, 1]

    pdf["score"] = scores

    preds = spark.createDataFrame(pdf)

    # --------------------------------------------------------
    # 6) Rank + dedupe
    # --------------------------------------------------------
    # Keep only the best-scoring row per (track_name, artist_name).
    w = Window.partitionBy("track_name", "artist_name").orderBy(F.desc("score"))

    final_df = (
        preds.withColumn("rk", F.row_number().over(w))
             .filter(col("rk") == 1)
             .orderBy(F.desc("score"))
             .limit(top_n)
             .select(
                 "track_name",
                 "artist_name",
                 "score",
                 F.round("tempo", 1).alias("tempo"),
                 F.round("energy", 2).alias("energy"),
                 F.round("danceability", 2).alias("dance"),
                 F.round("valence", 2).alias("mood"),
                 F.round("loudness", 1).alias("loud"),
                 F.round("acousticness", 2).alias("acoust"),
             )
    )

    print(f"\nüéß Top {top_n} Recommendations:")
    # show() prints only 20 rows by default; pass top_n so the display
    # matches the header even when top_n > 20.
    final_df.show(top_n, truncate=False)

    return final_df


StatementMeta(, 18184d11-bac4-4d8a-a7b4-898a7859c4e4, 11, Finished, Available, Finished)

‚úÖ Prebuilt Annoy indexes loaded.


In [None]:
# ============================================
# CELL 3 — BATCH RECOMMENDER JOB (PRODUCTION)
# ============================================

from pyspark.sql.functions import current_timestamp

USER_TABLE   = "silver_user_profile"
OUTPUT_TABLE = "gold_recommendations"
TOP_N        = 15

# Load all users with emails + display names.
# De-duplicate on the user id alone: distinct() over all three columns would
# keep several rows for a user whose email/display name differs between
# profile rows, and that user would then be processed (and written) repeatedly.
user_df = (
    spark.table(USER_TABLE)
    .select("spotify_user_id", "email", "display_name")
    .dropna(subset=["spotify_user_id"])
    .dropDuplicates(["spotify_user_id"])
)

# ============================================
# LOAD MODEL METADATA
# ============================================

# Most recent metadata row, by trained_at.
meta_rows = (
    spark.table("gold_model_metadata")
    .orderBy(F.desc("trained_at"))
    .limit(1)
    .collect()
)
if not meta_rows:
    # Fail with a clear message instead of a bare IndexError on [0].
    raise Exception("No rows found in gold_model_metadata; cannot tag recommendations with a model version.")
meta = meta_rows[0]

MODEL_VERSION = meta["model_version"]
MODEL_TRAINED_AT = meta["trained_at"]

print("üìå Loaded model metadata:")
print("MODEL_VERSION   =", MODEL_VERSION)
print("MODEL_TRAINED_AT=", MODEL_TRAINED_AT)


users = user_df.collect()
print(f"üë• Users to process: {len(users):,}")

batch_rows = []
failed_users = []  # ids of users whose recommendation run raised

for user in users:
    uid = user.spotify_user_id
    email = user.email
    display = user.display_name

    print(f"\n‚û° Recommending for: {uid} ({display})")

    try:
        recs = unified_discovery(uid, ocean_weight=0.2, top_n=TOP_N)
        if recs is None:
            continue

        pdf = recs.toPandas()

        for r in pdf.itertuples(index=False):
            batch_rows.append((
                uid,
                email,
                display,
                r.track_name,
                r.artist_name,
                float(r.score),
                MODEL_VERSION,
                MODEL_TRAINED_AT
            ))

    except Exception as e:
        # Best-effort per user: one failing user must not abort the batch,
        # but remember it so the failure count is reported at the end.
        print(f"‚ö† Error for {uid}: {e}")
        failed_users.append(uid)
        continue


if failed_users:
    print(f"Users failed: {len(failed_users):,} of {len(users):,}")

# Safety
if not batch_rows:
    raise Exception("‚ùå No recommendations generated. batch_rows is empty!")

# Convert to Spark DF
final_df = spark.createDataFrame(
    batch_rows,
    [
        "spotify_user_id",
        "email",
        "display_name",
        "track_name",
        "artist_name",
        "score",
        "model_version",
        "model_trained_at"
    ]
)

# Add batch timestamp
final_df = final_df.withColumn("batch_generated_at", current_timestamp())

# Save to Delta table (replaces the previous batch)
final_df.write.mode("overwrite").saveAsTable(OUTPUT_TABLE)

print(f"‚úÖ Saved recommendations to {OUTPUT_TABLE}")
# The row count equals the number of local tuples; no need to run another Spark job.
print(f"üì¶ Rows: {len(batch_rows):,}")
