In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from datetime import datetime
import pytz
import time

# Use 4 partitions for Spark SQL shuffles in this session.
spark.conf.set("spark.sql.shuffle.partitions", "4")

# Paths & Config
CATALOG = "testing"
SCHEMA = "processed_data"
# Delta location the main loop polls for new records.
INPUT_PATH = "/Volumes/testing/processed_data/staging_volume/staging_data/"
# Delta location detections are appended to.
OUTPUT_PATH = "/Volumes/testing/processed_data/staging_volume/output_detections_data/"
# Table supplying the source / target / type_trans / weight columns.
IMPORTANCE_TABLE = "transactions.financial_transactions_google_drive.customer_importance"
# Table that stores the last processed ingestion timestamp.
PROGRESS_TABLE = ".".join((CATALOG, SCHEMA, "detection_progress"))

# ───────────────
# Load the importance table once and normalise it
# ───────────────
# source / target / type_trans: single quotes removed, lower-cased, trimmed.
# type_trans additionally has every underscore removed.
# weight is cast to double; rows with a null in any selected column are dropped.
# NOTE(review): the batch "category" column that type_trans is joined against
# in detect_patterns is cleaned without the underscore removal, so categories
# containing "_" would presumably never match — confirm this is intended.
imp_df = (
    spark.read.table(IMPORTANCE_TABLE)
    .select("source", "target", "type_trans", "weight")
    .withColumn("source", trim(lower(regexp_replace(col("source"), "'", ""))))
    .withColumn("target", trim(lower(regexp_replace(col("target"), "'", ""))))
    .withColumn(
        "type_trans",
        regexp_replace(
            trim(lower(regexp_replace(col("type_trans"), "'", ""))), "_", ""
        ),
    )
    .withColumn("weight", col("weight").cast("double"))
    .na.drop()
)

# Current wall-clock time in the Asia/Kolkata zone
def ist_now():
    """Return the current Asia/Kolkata time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    kolkata = pytz.timezone("Asia/Kolkata")
    current = datetime.now(kolkata)
    return current.strftime('%Y-%m-%d %H:%M:%S')

# Detect patterns
def detect_patterns(df):
    """Apply the three detection rules to a batch and return all hits.

    Args:
        df: Batch DataFrame. The rules read its ``merchant``, ``customer``,
            ``category``, ``amount`` and ``gender`` columns.

    Returns:
        DataFrame with the columns YStartTime, detectionTime, patternId,
        ActionType, customerName and merchantId, holding the union of the
        PatId1 (UPGRADE), PatId2 (CHILD) and PatId3 (DEI-NEEDED) hits.
    """
    run_started = ist_now()

    def _to_detection(frame, pattern_id, action_type, customer_expr):
        # Project one rule's matches onto the shared detection schema.
        return frame.selectExpr(
            f"'{run_started}' as YStartTime",
            "current_timestamp() as detectionTime",
            f"'{pattern_id}' as patternId",
            f"'{action_type}' as ActionType",
            f"{customer_expr} as customerName",
            "merchant as merchantId",
        )

    # Pattern 1: UPGRADE
    # Only transactions that match an importance row (same merchant, customer
    # and category) take part; total_txns therefore counts joined rows.
    match_on = (
        (df["merchant"] == imp_df["target"])
        & (df["customer"] == imp_df["source"])
        & (df["category"] == imp_df["type_trans"])
    )
    weighted_txns = df.join(imp_df, match_on, "inner")
    per_pair = weighted_txns.groupBy("merchant", "customer").agg(
        count("*").alias("txn_count"),
        avg("weight").alias("avg_weight"),
    )
    per_merchant = weighted_txns.groupBy("merchant").agg(
        count("*").alias("total_txns")
    )
    busy_pairs = per_pair.join(per_merchant, "merchant").filter(
        col("total_txns") >= 50000
    )

    if busy_pairs.count() == 0:
        # No merchant reached the volume floor: keep the union well-typed
        # with an empty frame of the detection schema.
        p1 = spark.createDataFrame([], "YStartTime STRING, detectionTime TIMESTAMP, patternId STRING, ActionType STRING, customerName STRING, merchantId STRING")
    else:
        # Approximate thresholds (relative error 0.05) computed across all
        # qualifying merchant/customer pairs, not per merchant.
        txn_cutoff = busy_pairs.approxQuantile("txn_count", [0.9], 0.05)[0]
        weight_cutoff = busy_pairs.approxQuantile("avg_weight", [0.1], 0.05)[0]
        upgrade_hits = busy_pairs.filter(
            (col("txn_count") >= txn_cutoff) & (col("avg_weight") <= weight_cutoff)
        )
        p1 = _to_detection(upgrade_hits, "PatId1", "UPGRADE", "customer")

    # Pattern 2: CHILD
    # Pairs with at least 80 transactions whose average amount is below 23.
    pair_spend = df.groupBy("merchant", "customer").agg(
        avg("amount").alias("avg_amount"),
        count("*").alias("txn_count"),
    )
    child_hits = pair_spend.filter(
        (col("avg_amount") < 23) & (col("txn_count") >= 80)
    )
    p2 = _to_detection(child_hits, "PatId2", "CHILD", "customer")

    # Pattern 3: DEI-NEEDED
    # Count distinct customers per merchant by gender; any gender value other
    # than "m" / "f" becomes null and is dropped before counting.
    gender_label = (
        when(col("gender") == "m", "Male")
        .when(col("gender") == "f", "Female")
        .otherwise(None)
    )
    labelled = df.withColumn("gender_norm", gender_label)
    distinct_customers = (
        labelled.dropna(subset=["gender_norm"])
        .select("merchant", "customer", "gender_norm")
        .distinct()
    )
    gender_counts = (
        distinct_customers.groupBy("merchant")
        .pivot("gender_norm", ["Male", "Female"])
        .count()
        .na.fill(0)
    )
    # Merchants with more than 100 female customers but still fewer than male.
    dei_hits = gender_counts.filter(
        (col("Female") > 100) & (col("Female") < col("Male"))
    )
    p3 = _to_detection(dei_hits, "PatId3", "DEI-NEEDED", "''")

    return p1.unionByName(p2).unionByName(p3)

# Read last progress
def get_last_processed_ts():
    """Return the largest ``last_ingestion_timestamp`` in PROGRESS_TABLE.

    Returns None when the table does not exist yet; the aggregate is also
    null (None) when the table holds no non-null timestamp.
    """
    if not spark.catalog.tableExists(PROGRESS_TABLE):
        return None
    progress = spark.read.table(PROGRESS_TABLE)
    newest = progress.agg({"last_ingestion_timestamp": "max"}).collect()
    return newest[0][0]

# Save progress
def update_progress(ts):
    """Replace the contents of PROGRESS_TABLE with a single row holding ``ts``.

    Args:
        ts: Value stored in the ``last_ingestion_timestamp`` column.
    """
    row = (ts,)
    checkpoint = spark.createDataFrame([row], ["last_ingestion_timestamp"])
    writer = checkpoint.write.mode("overwrite")
    writer.saveAsTable(PROGRESS_TABLE)

# --- Main Loop ---
# Poll INPUT_PATH for records newer than the stored progress timestamp, run
# the detections on them, append any hits to OUTPUT_PATH and advance the
# progress marker. The loop ends after MAX_RETRIES consecutive empty polls.
# It re-raises after MAX_RETRIES consecutive failed iterations, so a
# persistent error (for example an unreadable input path) cannot keep the
# loop spinning forever.
no_data_counter = 0
error_counter = 0
MAX_RETRIES = 5

while no_data_counter < MAX_RETRIES:
    try:
        last_ts = get_last_processed_ts()
        df_batch = spark.read.format("delta").load(INPUT_PATH)

        # First run (no progress recorded yet) processes everything.
        if last_ts is not None:
            df_batch = df_batch.filter(col("ingestion_timestamp") > lit(last_ts))

        if df_batch.count() > 0:
            print(f"\n📥 New records found since {last_ts}")

            # ───────────────
            # Clean & preprocess current chunk
            # ───────────────
            # Same normalisation as the importance table's key columns:
            # strip single quotes, lower-case, trim.
            chunk = df_batch \
                .withColumn("merchant", trim(lower(regexp_replace(col("merchant"), "'", "")))) \
                .withColumn("customer", trim(lower(regexp_replace(col("customer"), "'", "")))) \
                .withColumn("category", trim(lower(regexp_replace(col("category"), "'", "")))) \
                .withColumn("amount", col("amount").cast("double")) \
                .withColumn("gender", trim(lower(col("gender")))) \
                .dropna(subset=["merchant", "customer"])

            result_df = detect_patterns(chunk)

            # Count once and reuse the number for the log line; counting again
            # after the write would recompute the whole detection plan.
            detection_count = result_df.count()
            if detection_count > 0:
                # NOTE(review): monotonically_increasing_id() is unique but
                # not consecutive (the partition id sits in the upper bits),
                # so floor(id / 50) does not produce groups of 50 rows —
                # confirm the intended batching.
                result_df = result_df.withColumn("batch_id", floor(monotonically_increasing_id() / 50))
                result_df.write.format("delta").mode("append").partitionBy("batch_id").save(OUTPUT_PATH)
                print(f"✅ Wrote {detection_count} detections to {OUTPUT_PATH}")
            else:
                print("⚠️ No detections found in this batch")

            # Advance the marker even when nothing was detected, so the same
            # records are not re-scanned on the next poll.
            latest_ts = df_batch.agg({"ingestion_timestamp": "max"}).collect()[0][0]
            update_progress(latest_ts)

            no_data_counter = 0
        else:
            no_data_counter += 1
            print(f"⏳ No new data to process... ({no_data_counter}/{MAX_RETRIES})")

    except Exception as e:
        error_counter += 1
        print(f"❌ Error: {e}")
        # Transient failures are retried; a failure that repeats MAX_RETRIES
        # times in a row is surfaced instead of being retried indefinitely.
        if error_counter >= MAX_RETRIES:
            raise
    else:
        error_counter = 0

    time.sleep(1)

print(f"\n🛑 No new data after {MAX_RETRIES} checks. Exiting detection loop.")


In [0]:
# Show everything written to the detections Delta location.
(
    spark.read
    .format("delta")
    .load(OUTPUT_PATH)
    .display()
)

In [0]:
# Re-declared so this cell can run on its own.
OUTPUT_PATH = "/Volumes/testing/processed_data/staging_volume/output_detections_data/"
# Show only the CHILD detections.
(
    spark.read
    .format("delta")
    .load(OUTPUT_PATH)
    .where(col("ActionType") == "CHILD")
    .display()
)
