# In [None]:
from sentence_transformers import SentenceTransformer
from pyspark.sql import SparkSession
from pyspark.sql.functions import concat_ws
import time
import numpy as np

# Step 1: Build (or reuse) a Hive-enabled Spark session.
# The settings are kept in one table so memory/parallelism tuning is visible
# at a glance; spark.driver.maxResultSize caps what may be returned to the driver.
session_settings = (
    ("spark.executor.memory", "8g"),
    ("spark.driver.memory", "16g"),
    ("spark.driver.maxResultSize", "4g"),
    ("spark.executor.cores", "4"),
    ("spark.sql.shuffle.partitions", "16"),
)
session_builder = SparkSession.builder.appName("FraudEmbeddingGeneration")
for setting_key, setting_value in session_settings:
    # Builder.config returns the builder itself, so reassigning keeps the chain.
    session_builder = session_builder.config(setting_key, setting_value)
spark = session_builder.enableHiveSupport().getOrCreate()

# Step 2: Read every column of the Hive table, then derive the model input.
print("📥 Reading from Hive table datamart.fraud_transactions...")
source_query = "SELECT * FROM datamart.fraud_transactions"
df = spark.sql(source_query)

# 'text' is the space-joined user_id, amount, category and country of each row;
# it is the only column consumed by the embedding step further down.
text_fields = ("user_id", "amount", "category", "country")
df_text = df.withColumn("text", concat_ws(" ", *text_fields))

# Step 3: Pull the 'text' column to the driver as a plain Python list.
# toLocalIterator() hands rows over partition by partition (per the PySpark
# docs the iterator needs about as much memory as the largest partition),
# which keeps the transfer itself small compared with a single collect().
# NOTE: the resulting list still holds every string in driver memory, so this
# only bounds the transfer peak, not the final footprint of `texts`.
print("⚙️ Collecting rows from Spark with toLocalIterator...")
texts = [row["text"] for row in df_text.select("text").toLocalIterator()]

print(f"✅ Collected {len(texts)} records")

# Step 4: Instantiate the sentence-embedding model.
print("🧠 Loading model on GPU...")
embedding_model_name = "BAAI/bge-small-en-v1.5"
# No device is passed explicitly, so the library picks it (CUDA if available).
model = SentenceTransformer(embedding_model_name)

# Step 5: Encode all collected texts in batches of 64 and report the elapsed time.
print("🚀 Generating embeddings...")
# perf_counter() is monotonic, so the measured duration cannot be skewed by
# system clock adjustments (NTP sync, DST) the way time.time() can.
start = time.perf_counter()
embeddings = model.encode(texts, show_progress_bar=True, batch_size=64)
duration = time.perf_counter() - start
print(f"✅ Generated {len(embeddings)} embeddings in {duration:.2f} seconds")

# Persist the embeddings as a NumPy .npy file; the path is relative, so the
# file is written to the current working directory.
np.save(file="embeddings.npy", arr=embeddings)
print("💾 Saved embeddings to embeddings.npy")
