## Item emb and retrieval top-k item

### Ensure sentence-transformers is installed on your cluster

In [0]:
%pip install --upgrade pip
%pip install sentence-transformers
%pip install faiss-cpu

[33mDEPRECATION: Using the pkg_resources metadata backend is deprecated. pip 26.3 will enforce this behaviour change. A possible replacement is to use the default importlib.metadata backend, by unsetting the _PIP_USE_IMPORTLIB_METADATA environment variable. Discussion can be found at https://github.com/pypa/pip/issues/13317[0m[33m
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m
[33mDEPRECATION: Using the pkg_resources metadata backend is deprecated. pip 26.3 will enforce this behaviour change. A possible replacement is to use the default importlib.metadata backend, by unsetting the _PIP_USE_IMPORTLIB_METADATA environment variable. Discussion can be found at https://github.com/pypa/pip/issues/13317[0m[33m
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m
[33mDEPRECATION: Using the pkg_resources metadata backend is de

### Import library

In [0]:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.types import ArrayType, StringType, StructType, StructField, FloatType
from sentence_transformers import SentenceTransformer
import pandas as pd
import os
import time
import faiss
import numpy as np
from pyspark.sql.functions import pandas_udf

### Configuration

In [0]:
GOLD_DATABASE = "`bigdata-and-bi`.gold"

# Input tables 
GOLD_ITEMS_TABLE = f"{GOLD_DATABASE}.star_items"

# Output tables
ITEM_VECTORS_TABLE = f"{GOLD_DATABASE}.star_item_vectors"

EMBEDDING_MODEL = 'all-MiniLM-L6-v2' # Fast, effective embedding model

### Load Gold Data

In [0]:
print("Loading Gold tables...")
try:
    items_df = spark.table(GOLD_ITEMS_TABLE)
    
    # Trigger an action to ensure tables exist
    items_df.count()
    print("Successfully loaded Gold tables (Items).")
except Exception as e:
    print(f"ERROR: Could not load tables. Please run Notebook 1 first.")
    dbutils.notebook.exit(f"Failed to read tables: {e}")

Loading Gold tables...
Successfully loaded Gold tables (Items).


### Generate Item Embeddings (Semantic Foundation - Rs)

In [0]:
print(f"Generating item embeddings using '{EMBEDDING_MODEL}'...")

# Setup schema
schema = StructType([
    StructField("item_id", StringType(), False),
    StructField("vector", VectorUDT(), False)
])

# Check if table exists (FIXED VERSION)
def table_exists(table_name):
    """Check if Delta table exists"""
    try:
        # Try to read table and trigger an action
        spark.table(table_name).limit(1).count()
        return True
    except Exception as e:
        # Table doesn't exist or not accessible
        return False

# Filter already processed items (WITH ERROR HANDLING)
if table_exists(ITEM_VECTORS_TABLE):
    try:
        existing_ids = spark.table(ITEM_VECTORS_TABLE).select("item_id")
        items_to_process = items_df.join(existing_ids, "item_id", "left_anti")
        existing_count = existing_ids.count()
        print(f"Found existing table with {existing_count:,} vectors")
        print(f"Filtered already processed items")
    except Exception as e:
        print(f"Warning reading existing table: {e}")
        print("Will process all items")
        items_to_process = items_df
else:
    items_to_process = items_df
    print("Table doesn't exist yet - will create new table")
    print("Processing all items")

# Get count safely
try:
    remaining_count = items_to_process.count()
    print(f"Total items to process: {remaining_count:,}")
except Exception as e:
    print(f"Error counting items: {e}")
    raise

if remaining_count == 0:
    print("All items already processed!")
else:
    # === OPTIMIZED DRIVER-ONLY PROCESSING ===
    
    print("\nLoading embedding model on driver...")
    model = SentenceTransformer(EMBEDDING_MODEL)
    print("Model loaded")
    
    # Configuration - OPTIMIZED for speed
    BATCH_SIZE = 50000  # Process 50K items per batch (collect less frequently)
    ENCODING_BATCH_SIZE = 512  # Encode 512 texts at once (GPU/CPU batch)
    
    print(f"\nConfiguration:")
    print(f"   • Batch size: {BATCH_SIZE:,} items")
    print(f"   • Encoding batch: {ENCODING_BATCH_SIZE}")
    print(f"   • Total batches: {(remaining_count // BATCH_SIZE) + 1}")
    
    start_time = time.time()
    total_processed = 0
    batch_num = 0
    
    # Process in batches
    while total_processed < remaining_count:
        batch_num += 1
        batch_start = time.time()
        
        print(f"\n{'='*60}")
        print(f"Batch {batch_num} / ~{(remaining_count // BATCH_SIZE) + 1}")
        print(f"{'='*60}")
        
        # Get batch
        try:
            batch_data = items_to_process.limit(BATCH_SIZE).collect()
        except Exception as e:
            print(f"Error collecting batch: {e}")
            break
        
        if len(batch_data) == 0:
            print("No more items to process")
            break
        
        # Extract data
        item_ids = [row.item_id for row in batch_data]
        texts = [row.prompt_text for row in batch_data]
        
        print(f"   Items in batch: {len(texts):,}")
        
        # Generate embeddings
        try:
            print(f"Encoding embeddings...")
            embeddings = model.encode(
                texts,
                batch_size=ENCODING_BATCH_SIZE,
                show_progress_bar=True,
                convert_to_numpy=True,
                normalize_embeddings=False
            )
            print(f"Encoding complete")
        except Exception as e:
            print(f"Error encoding: {e}")
            break
        
        # Create results
        results = [
            {
                "item_id": item_ids[i], 
                "vector": Vectors.dense(embeddings[i].tolist())
            }
            for i in range(len(item_ids))
        ]
        
        # Save to Delta
        try:
            print(f"Saving to Delta table...")
            batch_df = spark.createDataFrame(results, schema=schema)
            
            # First batch: create table with overwrite, rest: append
            if batch_num == 1 and not table_exists(ITEM_VECTORS_TABLE):
                batch_df.write \
                    .format("delta") \
                    .mode("overwrite") \
                    .saveAsTable(ITEM_VECTORS_TABLE)
                print(f"Created table and saved first batch")
            else:
                batch_df.write \
                    .format("delta") \
                    .mode("append") \
                    .saveAsTable(ITEM_VECTORS_TABLE)
                print(f"Appended to table")
                
        except Exception as e:
            print(f"Error saving: {e}")
            print(f"Check table name: {ITEM_VECTORS_TABLE}")
            break
        
        # Update progress
        total_processed += len(texts)
        batch_elapsed = time.time() - batch_start
        batch_speed = len(texts) / batch_elapsed if batch_elapsed > 0 else 0
        
        # Statistics
        print(f"\nBatch Statistics:")
        print(f"      • Time: {batch_elapsed:.1f}s")
        print(f"      • Speed: {batch_speed:.1f} items/s")
        print(f"      • Processed: {total_processed:,} / {remaining_count:,} ({100*total_processed/remaining_count:.1f}%)")
        
        # Calculate ETA
        if total_processed > 0:
            elapsed_total = time.time() - start_time
            avg_speed = total_processed / elapsed_total
            remaining = remaining_count - total_processed
            eta_seconds = remaining / avg_speed if avg_speed > 0 else 0
            eta_hours = eta_seconds / 3600
            
            print(f"      • Overall speed: {avg_speed:.1f} items/s")
            print(f"      • ETA: {eta_hours:.1f} hours ({eta_seconds/60:.0f} minutes)")
        
        # Remove processed items from queue
        try:
            processed_ids = spark.createDataFrame(
                [{"item_id": id} for id in item_ids]
            )
            items_to_process = items_to_process.join(
                processed_ids, "item_id", "left_anti"
            )
        except Exception as e:
            print(f"Warning: Could not update queue: {e}")
            # Continue anyway - duplicate handling will be done by Delta
    
    # Final summary
    total_elapsed = time.time() - start_time
    final_speed = total_processed / total_elapsed if total_elapsed > 0 else 0
    
    print(f"\n{'='*60}")
    print(f"PROCESSING COMPLETE")
    print(f"{'='*60}")
    print(f"Processed: {total_processed:,} items")
    print(f"Total time: {total_elapsed/60:.1f} minutes ({total_elapsed/3600:.2f} hours)")
    print(f"Average speed: {final_speed:.1f} items/second")
    print(f"{'='*60}")

# Verify final count (SAFE VERSION)
print("\nVerifying final count...")
try:
    if table_exists(ITEM_VECTORS_TABLE):
        final_count = spark.table(ITEM_VECTORS_TABLE).count()
        print(f"Total vectors in table: {final_count:,}")
        
        # Show sample
        print("\nSample results:")
        spark.table(ITEM_VECTORS_TABLE).select("item_id").show(5, truncate=False)
    else:
        print("Table was not created (no items processed)")
except Exception as e:
    print(f" Could not verify table: {e}")

{"ts": "2025-11-14 08:53:42.791", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_MultiThreadedRendezvous", "msg": "<_MultiThreadedRendezvous of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[TABLE_OR_VIEW_NOT_FOUND] The table or view `bigdata-and-bi`.`gold`.`star_item_vectors` cannot be found. Verify the spelling and correctness of the schema and catalog.\nIf you did not qualify the name with a schema, verify the current_schema() output, or qualify the name with the correct schema and catalog.\nTo tolerate the error on drop use DROP VIEW IF EXISTS or DROP TABLE IF EXISTS. SQLSTATE: 42P01;\n'Aggregate [unresolvedalias(count(1))]\n+- 'GlobalLimit 1\n   +- 'LocalLimit 1\n      +- 'UnresolvedRelation [bigdata-and-bi, gold, star_item_vectors], [], false\n\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {created_time:\"2025-11-14T08:53:42.791075021+00:00\", grpc_status:1

Generating item embeddings using 'all-MiniLM-L6-v2'...
⚠️  Note: Using driver-only processing due to Serverless limitations
✓ Table doesn't exist yet - will create new table
✓ Processing all items


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2


📊 Total items to process: 495,062

🔧 Loading embedding model on driver...
✅ Model loaded

⚙️  Configuration:
   • Batch size: 50,000 items
   • Encoding batch: 512
   • Total batches: 10

📦 Batch 1 / ~10
   Items in batch: 50,000
   🔄 Encoding embeddings...


Batches:   0%|          | 0/98 [00:00<?, ?it/s]

IOStream.flush timed out


   ✅ Encoding complete
   💾 Saving to Delta table...


{"ts": "2025-11-14 09:23:34.201", "level": "ERROR", "logger": "pyspark.sql.connect.logging", "msg": "GRPC Error received", "context": {}, "exception": {"class": "_MultiThreadedRendezvous", "msg": "<_MultiThreadedRendezvous of RPC that terminated with:\n\tstatus = StatusCode.INTERNAL\n\tdetails = \"[TABLE_OR_VIEW_NOT_FOUND] The table or view `bigdata-and-bi`.`gold`.`star_item_vectors` cannot be found. Verify the spelling and correctness of the schema and catalog.\nIf you did not qualify the name with a schema, verify the current_schema() output, or qualify the name with the correct schema and catalog.\nTo tolerate the error on drop use DROP VIEW IF EXISTS or DROP TABLE IF EXISTS. SQLSTATE: 42P01;\n'Aggregate [unresolvedalias(count(1))]\n+- 'GlobalLimit 1\n   +- 'LocalLimit 1\n      +- 'UnresolvedRelation [bigdata-and-bi, gold, star_item_vectors], [], false\n\"\n\tdebug_error_string = \"UNKNOWN:Error received from peer  {grpc_message:\"[TABLE_OR_VIEW_NOT_FOUND] The table or view `bigdata

   ✅ Created table and saved first batch

   📊 Batch Statistics:
      • Time: 1794.8s
      • Speed: 27.9 items/s
      • Processed: 50,000 / 495,062 (10.1%)
      • Overall speed: 27.9 items/s
      • ETA: 4.4 hours (266 minutes)

📦 Batch 2 / ~10
   Items in batch: 50,000
   🔄 Encoding embeddings...


Batches:   0%|          | 0/98 [00:00<?, ?it/s]

com.databricks.backend.common.rpc.CommandCancelledException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$5(SequenceExecutionState.scala:139)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:139)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:136)
	at scala.collection.immutable.Range.foreach(Range.scala:192)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:136)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:721)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:441)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:441)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.can