# In [None]:
from typing import Iterator

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import concat_ws
from sentence_transformers import SentenceTransformer

# Step 1: Build (or reuse) the Spark session used by the rest of the script,
# with fixed executor/driver resources and Hive metastore access enabled.
# NOTE(review): getOrCreate() hands back an already-running session if one
# exists (e.g. inside a notebook); static resource settings such as
# executor/driver memory cannot change on a live session — confirm they
# take effect in your launch mode.
spark = (
    SparkSession.builder
    .appName("DistributedEmbedding")
    .config("spark.executor.memory", "8g")
    .config("spark.driver.memory", "16g")
    .config("spark.executor.cores", "4")
    .config("spark.sql.shuffle.partitions", "100")
    .enableHiveSupport()
    .getOrCreate()
)

# Steps 2-3: Read the transactions from Hive and derive the 'text' column
# that is fed to the embedding model: user_id, amount, category and country
# joined with single spaces.
# NOTE(review): non-string columns such as amount are presumably cast to
# string implicitly by concat_ws, and NULL inputs are skipped rather than
# nulling the whole string — confirm this is the desired embedding text.
df = (
    spark.sql("SELECT transaction_id, user_id, amount, category, country FROM datamart.fraud_transactions")
    .withColumn("text", concat_ws(" ", "user_id", "amount", "category", "country"))
)

# Step 4: Per-partition embedding function executed on the Spark workers.
def embed_partition(
    batches: Iterator[pd.DataFrame],
    *,
    model=None,
    model_name: str = "BAAI/bge-small-en-v1.5",
    batch_size: int = 64,
) -> Iterator[pd.DataFrame]:
    """Embed the ``text`` column of every pandas batch in one Spark partition.

    ``DataFrame.mapInPandas`` calls this once per partition with an
    *iterator* of pandas DataFrames and expects an iterator of DataFrames
    back, so this is a generator. The encoder is created at most once per
    call and only when a non-empty batch arrives, so empty partitions never
    pay the model-loading cost.

    Args:
        batches: pandas DataFrames carrying ``transaction_id`` and ``text``
            columns.
        model: optional pre-built encoder exposing
            ``encode(texts, show_progress_bar=..., batch_size=...)``. When
            ``None``, ``SentenceTransformer(model_name)`` is loaded lazily.
        model_name: sentence-transformers model loaded when ``model`` is None.
        batch_size: batch size forwarded to ``model.encode``.

    Yields:
        One DataFrame per non-empty input batch: ``transaction_id`` first,
        then ``emb_0 .. emb_{d-1}`` as float32, i.e. the same column order
        as the ``mapInPandas`` output schema.
        NOTE(review): the caller's schema hard-codes 384 dims (bge-small);
        a different ``model_name`` must match it.
    """
    for pdf in batches:
        # An empty batch has nothing to embed (and would leave no row to
        # infer the embedding width from), so skip it.
        if pdf.empty:
            continue
        if model is None:
            # Imported and loaded inside the worker, once per partition.
            from sentence_transformers import SentenceTransformer
            model = SentenceTransformer(model_name)  # GPU-enabled if available
        embeddings = np.asarray(
            model.encode(
                pdf["text"].tolist(),
                show_progress_bar=False,
                batch_size=batch_size,
            ),
            dtype=np.float32,  # schema declares the emb_* columns as FLOAT
        )
        emb_df = pd.DataFrame(
            embeddings,
            columns=[f"emb_{i}" for i in range(embeddings.shape[1])],
        )
        # Positional copy (.to_numpy()) so the input batch's index can never
        # misalign with the fresh RangeIndex of emb_df.
        emb_df.insert(0, "transaction_id", pdf["transaction_id"].to_numpy())
        yield emb_df

# Step 5: Embed in parallel with mapInPandas. The output schema is the
# transaction id followed by one FLOAT column per embedding dimension;
# 384 dims for bge-small, which must match what embed_partition emits.
schema = ", ".join(["transaction_id LONG"] + [f"emb_{i} FLOAT" for i in range(384)])

embedding_df = (
    df.select("transaction_id", "text")
    .repartition(100)  # spread the model-bound work over 100 partitions/tasks
    .mapInPandas(embed_partition, schema=schema)
)

# Step 6: Persist the embeddings as Parquet on HDFS, replacing any previous output.
# NOTE(review): transaction_id is declared LONG above; confirm the Hive
# column type matches, otherwise the Arrow conversion will fail.
embedding_df.write.mode("overwrite").parquet("hdfs:///tmp/fraud_embeddings_parquet")
print("✅ Embeddings written to HDFS: /tmp/fraud_embeddings_parquet")