In [0]:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import col, lit, current_timestamp, when, count, mean, expr
from datetime import datetime


# Reference DataFrame read from the customer CSV; the header row supplies the
# column names. foreach_batch_function joins every micro-batch against it.
customer_df = spark.read.options(header="true").csv(
    "abfss://gdrive-ingest@devdolphinstorage.dfs.core.windows.net/reference_data/customer data.csv"
)

# State carried across micro-batches; foreach_batch_function declares these
# names `global` and rebinds/mutates them.
# Both running-aggregate DataFrames stay None until the first batch is processed.
merchant_transaction_count_df = customer_merchant_stats_df = None
# (customer, merchant) pairs that have already been reported for Pattern 1.
already_detected_pat1 = set()

# Streaming source: the "cloudFiles" reader over CSV files (with a header row)
# under the transactions/ directory, with its schema location under
# schema/ChunksSchema/.
streaming_df = (
    spark.readStream.format("cloudFiles")
    .options(
        **{
            "cloudFiles.format": "csv",
            "header": "true",
            "cloudFiles.schemaLocation": "abfss://gdrive-ingest@devdolphinstorage.dfs.core.windows.net/schema/ChunksSchema/",
        }
    )
    .load("abfss://gdrive-ingest@devdolphinstorage.dfs.core.windows.net/transactions/")
)

def foreach_batch_function(batch_df, batch_id):
    """Fold one micro-batch into the running aggregates and emit Pattern 1 detections.

    Steps, in order:

    1. Inner-join the batch with ``customer_df`` on
       customer/Source, merchant/Target, category/typeTrans and amount/Weight.
    2. Add the per-merchant row counts of the joined rows to
       ``merchant_transaction_count_df`` (column ``total_txn``).
    3. Merge the per-(customer, merchant) row count and mean ``Weight`` into
       ``customer_merchant_stats_df`` (columns ``txn_count``, ``avg_weight``).
    4. For every merchant whose ``total_txn`` exceeds 50000, flag customers
       whose ``txn_count`` is at or above the merchant's 90th percentile and
       whose ``avg_weight`` is at or below the merchant's 10th percentile.
       Pairs not reported before are appended to the ``already_detected_pat1``
       table over JDBC.

    Args:
        batch_df: DataFrame holding the micro-batch, as supplied by
            ``foreachBatch``.
        batch_id: Identifier of the micro-batch; only used in log output.

    Returns:
        None. Results are the rebound module-level state and the JDBC append.

    Notes:
        * ``jdbc_url`` and ``jdbc_props`` are not defined in this section of
          the file; they are assumed to exist at module level -- TODO confirm.
        * NOTE(review): the aggregate DataFrames are rebuilt lazily on top of
          the previous ones every batch, so their query plans keep growing,
          and the in-memory state is not restored from the stream checkpoint
          after a restart. Consider persisting the aggregates externally.
    """
    global merchant_transaction_count_df, customer_merchant_stats_df, already_detected_pat1

    print(f"\n Processing batch {batch_id} rows: {batch_df.count()}")

    # Keep only transactions that match a reference row on all four attributes.
    merged_df = batch_df.join(
        customer_df,
        (batch_df["customer"] == customer_df["Source"]) &
        (batch_df["merchant"] == customer_df["Target"]) &
        (batch_df["category"] == customer_df["typeTrans"]) &
        (batch_df["amount"] == customer_df["Weight"]),
        how="inner"
    )

    print(f" Merged rows: {merged_df.count()}")

    # ---- Running transaction count per merchant --------------------------
    merchant_txn = (
        merged_df.groupBy("merchant")
        .agg(count("*").alias("new_txn_count"))
    )

    if merchant_transaction_count_df is None:
        merchant_transaction_count_df = merchant_txn.withColumnRenamed("new_txn_count", "total_txn")
    else:
        # Outer join keeps merchants seen only before or only in this batch;
        # the missing side's count is filled with 0 before adding.
        merchant_transaction_count_df = (
            merchant_transaction_count_df.join(merchant_txn, on="merchant", how="outer")
            .na.fill(0)
            .withColumn("total_txn", col("total_txn") + col("new_txn_count"))
            .select("merchant", "total_txn")
        )

    print(f" Updated merchant_transaction_count_df count: {merchant_transaction_count_df.count()}")

    # ---- Running count and mean weight per (customer, merchant) ----------
    cust_merchant_stats = (
        merged_df.groupBy("customer", "merchant")
        .agg(
            count("*").alias("new_txn_count"),
            mean("Weight").alias("new_avg_weight")
        )
    )

    if customer_merchant_stats_df is None:
        customer_merchant_stats_df = (
            cust_merchant_stats
            .withColumnRenamed("new_txn_count", "txn_count")
            .withColumnRenamed("new_avg_weight", "avg_weight")
        )
    else:
        # Both expressions refer to the PRE-merge txn_count, so avg_weight must
        # be derived before txn_count is overwritten below. Doing it in the
        # opposite order weights the old mean by the already-summed count and
        # divides by old + 2 * new.
        merged_count = col("txn_count") + col("new_txn_count")
        weighted_sum = (
            col("avg_weight") * col("txn_count")
            + col("new_avg_weight") * col("new_txn_count")
        )

        customer_merchant_stats_df = (
            customer_merchant_stats_df.join(
                cust_merchant_stats, on=["customer", "merchant"], how="outer"
            )
            .na.fill(0)
            .withColumn(
                "avg_weight",
                when(merged_count == 0, 0).otherwise(weighted_sum / merged_count)
            )
            .withColumn("txn_count", merged_count)
            .select("customer", "merchant", "txn_count", "avg_weight")
        )

    print(f" Updated customer_merchant_stats_df count: {customer_merchant_stats_df.count()}")

    # ---- Pattern 1 detection ---------------------------------------------
    eligible_merchants = merchant_transaction_count_df.filter(col("total_txn") > 50000).select("merchant").distinct()
    eligible_merchants_list = [row["merchant"] for row in eligible_merchants.collect()]

    detections = []
    # Keys found in this batch. They are merged into already_detected_pat1 only
    # after the JDBC write succeeds, so a failed write does not cause the
    # detections to be skipped when the batch is retried.
    new_keys = set()

    for merchant in eligible_merchants_list:
        cust_stats = customer_merchant_stats_df.filter(col("merchant") == merchant)

        if cust_stats.count() == 0:
            continue

        # relativeError=0.0 requests exact quantiles.
        txn_quantiles = cust_stats.approxQuantile("txn_count", [0.9], 0.0)
        weight_quantiles = cust_stats.approxQuantile("avg_weight", [0.1], 0.0)

        # approxQuantile returns an empty list for a column with no non-null
        # values (e.g. avg_weight when no Weight could be averaged).
        if not txn_quantiles or not weight_quantiles:
            continue

        txn_thresh = txn_quantiles[0]
        weight_thresh = weight_quantiles[0]

        detected = cust_stats.filter(
            (col("txn_count") >= txn_thresh) & (col("avg_weight") <= weight_thresh)
        ).select("customer", "merchant")

        for row in detected.collect():
            key = (row["customer"], row["merchant"])
            if key in already_detected_pat1 or key in new_keys:
                continue

            # One timestamp per detection so both time fields always agree.
            detected_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            detections.append({
                "YStartTime": detected_at,
                "detectionTime": detected_at,
                "patternId": "PatId1",
                "ActionType": "UPGRADE",
                "customerName": row["customer"],
                "MerchantId": row["merchant"]
            })
            new_keys.add(key)

    if detections:
        detections_df = spark.createDataFrame(detections)
        detections_df.show(truncate=False)

        detections_df.write.jdbc(jdbc_url, "already_detected_pat1", mode="append", properties=jdbc_props)
        already_detected_pat1.update(new_keys)

        print(f" Wrote {len(detections)} Pattern 1 detections to Postgres.")
    else:
        print(" No new detections for Pattern 1.")

# ----------------------------------------
# ✅ Start Stream
# ----------------------------------------
# Wire the stream to foreach_batch_function in append mode, with the
# checkpoint location under checkpoints/pat1/.
query = (
    streaming_df.writeStream
    .outputMode("append")
    .option("checkpointLocation", "abfss://gdrive-ingest@devdolphinstorage.dfs.core.windows.net/checkpoints/pat1/")
    .foreachBatch(foreach_batch_function)
    .start()
)

# Block until the streaming query terminates.
query.awaitTermination()
