In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import concat_ws
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np

In [None]:
# Step 1: Spark session
# Every tunable lives in this dict and is applied to the builder in order.
# NOTE(review): getOrCreate() hands back an already-running session if one exists;
# static settings (executor/driver memory, cores) cannot change on a live session.
# NOTE(review): each concurrent task (up to spark.executor.cores) loads its own
# SentenceTransformer in a Python worker, outside the JVM heap sized by
# spark.executor.memory — consider spark.executor.memoryOverhead if containers die.
SPARK_CONF = {
    "spark.executor.memory": "8g",
    "spark.driver.memory": "8g",
    "spark.executor.cores": "8",
    "spark.sql.shuffle.partitions": "16",
    # Max rows per Arrow record batch handed to pandas in mapInPandas.
    "spark.sql.execution.arrow.maxRecordsPerBatch": "10000",
    "spark.task.maxFailures": "8",
}

spark_builder = SparkSession.builder.appName("DistributedEmbedding")
for conf_key, conf_value in SPARK_CONF.items():
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.enableHiveSupport().getOrCreate()


In [None]:
# Step 2: Load Hive table
# Pull only the key plus the fields that feed the embedding text in Step 3.
FRAUD_QUERY = (
    "SELECT transaction_id, user_id, amount, category, country "
    "FROM datamart.fraud_transactions"
)
df = spark.sql(FRAUD_QUERY)


In [None]:
# Step 3: Add 'text' column for embedding input
# The space-joined values of these columns form the sentence that gets embedded.
# NOTE(review): concat_ws skips NULL inputs, so a missing field silently shortens
# the text rather than leaving a placeholder — confirm that is acceptable.
TEXT_SOURCE_COLS = ["user_id", "amount", "category", "country"]
df = df.withColumn("text", concat_ws(" ", *TEXT_SOURCE_COLS))


In [None]:
# Step 4: Define function to run inside Spark workers
def embed_partition(pdf, model_name: str = "BAAI/bge-small-en-v1.5", batch_size: int = 64):
    """Embed the ``text`` column of pandas batches with a SentenceTransformer.

    ``DataFrame.mapInPandas`` calls its function with an *iterator* of pandas
    DataFrames (one per Arrow batch of a Spark partition) and expects an
    iterator of DataFrames back. That is the primary input form here. A single
    DataFrame is also accepted (e.g. for local testing), in which case a single
    DataFrame is returned.

    Args:
        pdf: A ``pd.DataFrame`` or an iterable of them, each with
            ``transaction_id`` and ``text`` columns.
        model_name: SentenceTransformer model to load. Its embedding width
            must match the number of ``emb_*`` columns in the Spark schema.
        batch_size: Batch size passed to ``model.encode``.

    Returns:
        For DataFrame input, one DataFrame; otherwise an iterator yielding one
        DataFrame per input batch. Each has columns
        ``transaction_id, emb_0, ..., emb_{d-1}`` in that order.
    """
    import pandas as pd

    if isinstance(pdf, pd.DataFrame):
        return _embed_frame(_load_model(model_name), pdf, batch_size)
    return _embed_batches(pdf, model_name, batch_size)


def _embed_batches(batches, model_name, batch_size):
    """Yield one embedding frame per input batch, loading the model at most once."""
    model = None
    for batch in batches:
        if model is None:
            # Lazy: a partition that yields no batches never pays the load cost.
            model = _load_model(model_name)
        yield _embed_frame(model, batch, batch_size)


def _load_model(model_name):
    """Load the SentenceTransformer; imported here so it resolves on the worker."""
    from sentence_transformers import SentenceTransformer

    return SentenceTransformer(model_name)  # GPU-enabled if available


def _embed_frame(model, pdf, batch_size):
    """Encode ``pdf['text']`` and return ``transaction_id`` followed by ``emb_*`` columns."""
    import numpy as np
    import pandas as pd

    texts = pdf["text"].tolist()
    if texts:
        vectors = np.asarray(model.encode(texts, show_progress_bar=False, batch_size=batch_size))
    else:
        # No row to size the columns from, so ask the model for its width.
        vectors = np.empty((0, model.get_sentence_embedding_dimension()), dtype=np.float32)

    emb_df = pd.DataFrame(vectors, columns=[f"emb_{i}" for i in range(vectors.shape[1])])
    # Positional copy (not index-aligned); placed first to match the output schema order.
    emb_df.insert(0, "transaction_id", pdf["transaction_id"].to_numpy())
    return emb_df

In [None]:
# Step 5: Run distributed embedding using mapInPandas
# Output width of BAAI/bge-small-en-v1.5; must match the model used in embed_partition.
EMBEDDING_DIM = 384
# Number of partitions (and so parallel embedding tasks) fed to mapInPandas.
NUM_EMBED_PARTITIONS = 32

# DDL schema: the key, then one FLOAT column per embedding dimension.
# NOTE(review): assumes transaction_id is BIGINT-compatible in the Hive table — TODO confirm.
emb_fields = ", ".join(f"emb_{i} FLOAT" for i in range(EMBEDDING_DIM))
schema = f"transaction_id LONG, {emb_fields}"

embedding_input = df.select("transaction_id", "text").repartition(NUM_EMBED_PARTITIONS)
# Persisted so later actions reuse the computed embeddings instead of re-running the model.
embedding_df = embedding_input.mapInPandas(embed_partition, schema=schema).persist()

In [None]:
# Step 6: Write embeddings to HDFS with optimized settings
# repartition(4), not coalesce(4): coalesce is a narrow dependency, so Spark would
# fold it into the upstream stage and run the embedding itself in only 4 tasks
# instead of the 32 partitions prepared in Step 5. repartition adds a shuffle
# boundary, so embedding keeps its parallelism and only the write uses 4 tasks.
# embedding_df stays persisted for any follow-up cells; call
# embedding_df.unpersist() once nothing else needs it.
output_dir = "/tmp/fraud_embeddings_parquet"

embedding_df.repartition(4).write \
    .mode("overwrite") \
    .option("compression", "snappy") \
    .parquet(f"hdfs://{output_dir}")

print(f"✅ Embeddings written to HDFS: {output_dir}")