In [0]:
import builtins as b
import math
import os
import sys
import time
from datetime import datetime, timedelta

from pyspark import StorageLevel
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, window, expr, udf, lit, to_timestamp, mean, stddev, max, count, first, min, last, unix_timestamp, trim, collect_list, monotonically_increasing_id, current_timestamp, date_format, lpad, concat, explode, flatten, element_at, array_except, array, array_remove, slice, size, when, collect_set, broadcast, array_intersect, sum, row_number, struct, abs, any_value, to_json
from pyspark.sql.types import IntegerType
from pyspark.sql.window import Window

In [0]:
# Wall-clock reference point; the last cell subtracts it from time.time() to report total run time.
start_time = time.time()
# Storage level handed to the persist() calls in the processing loop:
# serialized (deserialized=False), disk + memory + off-heap enabled, single replica.
serialized_storage = StorageLevel(useDisk=True, useMemory=True, useOffHeap=True, deserialized=False, replication=1)

In [0]:
# Snowflake connector options for the THRES schema (read by the threshold-loading cell).
# Credentials are taken from the environment instead of being stored in the notebook:
#   SNOWFLAKE_USER     - optional, defaults to the analyst service account name
#   SNOWFLAKE_PASSWORD - required; if unset the password is empty and the first
#                        Snowflake read fails with an authentication error.
snowflake_options_thres = {
    "sfURL": "https://dumzrka-fy75904.snowflakecomputing.com",
    "sfWarehouse": "TAKUMI_AN_01_WH",
    "sfDatabase": "TAKUMI_03",
    "sfSchema": "THRES",
    "sfRole": "DATA_ANALYST",
    "user": os.environ.get("SNOWFLAKE_USER", "TAKUMI_AN_01"),
    "password": os.environ.get("SNOWFLAKE_PASSWORD", ""),
}

In [0]:
# Snowflake connector options for the ETL schema (source of the MOIG_TESTING transactions).
# Credentials are taken from the environment instead of being stored in the notebook:
#   SNOWFLAKE_USER     - optional, defaults to the analyst service account name
#   SNOWFLAKE_PASSWORD - required; if unset the password is empty and the first
#                        Snowflake read fails with an authentication error.
snowflake_options = {
    "sfURL": "https://dumzrka-fy75904.snowflakecomputing.com",
    "sfWarehouse": "TAKUMI_AN_01_WH",
    "sfDatabase": "TAKUMI_03",
    "sfSchema": "ETL",
    "sfRole": "DATA_ANALYST",
    "user": os.environ.get("SNOWFLAKE_USER", "TAKUMI_AN_01"),
    "password": os.environ.get("SNOWFLAKE_PASSWORD", ""),
}

In [0]:
# Snowflake connector options for the ALERTS schema (ALERT, ALERTED_TRANSACTION and
# THRESHOLD_BREACH are read from / appended to with these options).
# Credentials are taken from the environment instead of being stored in the notebook:
#   SNOWFLAKE_USER     - optional, defaults to the analyst service account name
#   SNOWFLAKE_PASSWORD - required; if unset the password is empty and the first
#                        Snowflake read fails with an authentication error.
snowflake_options_alert = {
    "sfURL": "https://dumzrka-fy75904.snowflakecomputing.com",
    "sfWarehouse": "TAKUMI_AN_01_WH",
    "sfDatabase": "TAKUMI_03",
    "sfSchema": "ALERTS",
    "sfRole": "DATA_ANALYST",
    "user": os.environ.get("SNOWFLAKE_USER", "TAKUMI_AN_01"),
    "password": os.environ.get("SNOWFLAKE_PASSWORD", ""),
}

In [0]:
# Find the overall time span of the not-yet-validated transactions.
# The epoch is built as <epoch seconds of the first 19 chars> * 1_000_000 + <last 6 chars>,
# i.e. the same expression the per-chunk filter and the Spark-side epoch column use.
# (The sub-second part previously used RIGHT(..., 9) here, which put the global
# bounds on a different scale from the chunk filter they drive.)
# The result aliases keep their historical "_ns" suffix although the value is built
# with a 1_000_000 multiplier.
time_for_query_session = time.time()
overall_query = """
    SELECT 
        MIN((DATE_PART('EPOCH_SECOND', TO_TIMESTAMP_NTZ(LEFT(transaction_timestamp, 19))) * 1000000) + CAST(RIGHT(transaction_timestamp, 6) AS BIGINT)) AS min_ts_ns,
        MAX((DATE_PART('EPOCH_SECOND', TO_TIMESTAMP_NTZ(LEFT(transaction_timestamp, 19))) * 1000000) + CAST(RIGHT(transaction_timestamp, 6) AS BIGINT)) AS max_ts_ns
    FROM  MOIG_TESTING
    WHERE (validation_flag IS NULL OR TRIM(validation_flag) = '') ;
"""

overall_df = spark.read.format("snowflake") \
    .options(**snowflake_options) \
    .option("query", overall_query) \
    .load()

# The aggregate always yields exactly one row; both values are NULL when no row matches.
overall_row = overall_df.collect()[0]
global_start_time = overall_row["MIN_TS_NS"]
global_end_time = overall_row["MAX_TS_NS"]

# Fail here with a clear message instead of a TypeError in the chunk loop's
# `current_chunk_start < global_end_time` comparison.
if global_start_time is None or global_end_time is None:
    raise ValueError("MOIG_TESTING contains no rows with an empty validation_flag; nothing to process.")

print(f"Global start time: {global_start_time}")
print(f"Global end time: {global_end_time}")

print(f"Time to query overall period: {time.time() - time_for_query_session} seconds")

Global start time: 1741858200000000
Global end time: 1741859978000038
Time to query overall period: 1.9349262714385986 seconds


In [0]:
# Render the one-row MIN/MAX result as a table for a quick visual check.
overall_df.display()

MIN_TS_NS,MAX_TS_NS
1741858200000000,1741859978000038


In [0]:
# Resolve the asset-analytics id for ANALYTICS_ID = 1, then load every threshold
# configured for it and expose each one as a module-level variable named after
# its THRESHOLD_IDENTIFIER.
asset_df = spark.read.format("snowflake").options(**snowflake_options_thres).option("query", """
    SELECT ASSET_ANALYTICS_ID as asset_an_id FROM ASSETANALYTICS WHERE ANALYTICS_ID = 1;
""").load()

asset_rows = asset_df.collect()
if not asset_rows:
    # Previously surfaced as a bare IndexError on collect()[0].
    raise ValueError("No ASSETANALYTICS row found for ANALYTICS_ID = 1.")
asset_row = asset_rows[0]
asset_analytics_id = asset_row["ASSET_AN_ID"]


# Read from Snowflake
query = f"""
    SELECT THRESHOLD_ID, THRESHOLD_IDENTIFIER, THRESHOLD_VALUE
    FROM AnalyticsThreshold
    WHERE asset_analytics_id = {asset_analytics_id}
"""
df = spark.read.format("snowflake").options(**snowflake_options_thres).option("query", query).load()

# identifier -> THRESHOLD_ID, used later when writing THRESHOLD_BREACH rows
threshold_id_map = {}
# Collect and assign values dynamically
for row in df.collect():
    threshold_identifier = row["THRESHOLD_IDENTIFIER"]
    threshold_value = row["THRESHOLD_VALUE"]
    threshold_id = row["THRESHOLD_ID"]

    # Assign to global variable.
    # NOTE(review): this writes whatever identifier the table contains into the
    # notebook namespace and can overwrite an existing global of the same name.
    globals()[threshold_identifier] = threshold_value

    # Store ID in a map
    threshold_id_map[threshold_identifier] = int(threshold_id)

# Every identifier below is read by name further down. Check them against the rows
# loaded in THIS run so a missing one fails here with a clear message rather than
# as a NameError (or, on a re-run, silently reusing a stale global).
required_thresholds = (
    "IGNITION_WINDOW",
    "FOLLOW_THROUGH",
    "REVERSAL_WINDOW",
    "MOIG_THRESHOLD_PRICE",
    "MOIG_THRESHOLD_VOLUME",
    "MOIG_THRESHOLD_MOMENTUM",
    "THRES_TRADER_ID_MATCH_PERCENTAGE_REVERSAL",
    "THRES_TRADER_ID_MATCH_PERCENTAGE_FOLLOW",
    "THRES_IGNITION_BUY_SELL_RATIO",
    "REVERSAL_THRESHOLD_SELL_BUY_RATIO",
    "PERCENTAGE_ADV30",
)
missing_thresholds = [name for name in required_thresholds if name not in threshold_id_map]
if missing_thresholds:
    raise ValueError(
        f"AnalyticsThreshold has no entry for asset_analytics_id = {asset_analytics_id} "
        f"and identifier(s): {', '.join(missing_thresholds)}"
    )

# Window lengths are configured in minutes (they are later used as INTERVAL ... MINUTES).
IGNITION_WINDOW = int(IGNITION_WINDOW)
FOLLOW_THROUGH = int(FOLLOW_THROUGH)
REVERSAL_WINDOW = int(REVERSAL_WINDOW)


# minutes * 60 * 10**6 -> microseconds. The "_NS" suffix is historical; these values
# are on the same scale as the epoch column (seconds * 1_000_000 + sub-second part).
IGNITION_WINDOW_NS = IGNITION_WINDOW * 60 * 10**6
FOLLOW_THROUGH_NS = FOLLOW_THROUGH * 60 * 10**6
REVERSAL_WINDOW_NS = REVERSAL_WINDOW * 60 * 10**6
TOTAL_WINDOW_NS = IGNITION_WINDOW_NS + FOLLOW_THROUGH_NS + REVERSAL_WINDOW_NS
# One hour on the same scale.
CHUNK_SIZE = 1 * 60 * 60 * 10**6
THRES_PRICE_MOIG = MOIG_THRESHOLD_PRICE
THRES_QUANTITY_MOIG = MOIG_THRESHOLD_VOLUME
# Percent-valued thresholds are converted to fractions.
THRES_MOMENTUM_MOIG = MOIG_THRESHOLD_MOMENTUM / 100
THRES_TRADER_ID_MATCH_PERCENTAGE_REVERSAL = THRES_TRADER_ID_MATCH_PERCENTAGE_REVERSAL / 100
THRES_TRADER_ID_MATCH_PERCENTAGE_FOLLOW = THRES_TRADER_ID_MATCH_PERCENTAGE_FOLLOW / 100

# Example: print the values to check
print(IGNITION_WINDOW_NS, FOLLOW_THROUGH_NS, REVERSAL_WINDOW_NS, TOTAL_WINDOW_NS, PERCENTAGE_ADV30, CHUNK_SIZE, THRES_PRICE_MOIG, THRES_QUANTITY_MOIG, THRES_MOMENTUM_MOIG, THRES_TRADER_ID_MATCH_PERCENTAGE_REVERSAL, THRES_TRADER_ID_MATCH_PERCENTAGE_FOLLOW, THRES_IGNITION_BUY_SELL_RATIO, REVERSAL_THRESHOLD_SELL_BUY_RATIO)

300000000 120000000 300000000 720000000 1.0 3600000000 1.0 1.0 0.02 0.05 0.6 1.0 1.0


In [0]:

# Walk the [global_start_time, global_end_time] span in chunks. Each iteration reads
# one chunk from Snowflake, detects momentum-ignition events in it and appends the
# resulting alerts to the ALERTS schema.
current_chunk_start = global_start_time
momentum_ignition_total_count = 0
next_window =  False  # False only until the first non-empty chunk has been processed
total_time_for_data = 0

while (current_chunk_start < global_end_time):
    time_to_check_empty = time.time()
    # The first chunk spans CHUNK_SIZE. Later chunks start TOTAL_WINDOW_NS before the
    # previous chunk's end (see the bottom of the loop), so they are extended by the
    # same amount to still cover CHUNK_SIZE of new data.
    if current_chunk_start == global_start_time:
        current_chunk_end = current_chunk_start + CHUNK_SIZE # CHUNK_SIZE = 1 * 60 * 60 * 10**6, i.e. one hour in microseconds
    else:
        current_chunk_end = current_chunk_start + CHUNK_SIZE + TOTAL_WINDOW_NS

    # BETWEEN is inclusive on both ends; the epoch expression is
    # <epoch seconds of the first 19 chars> * 1_000_000 + <last 6 chars>.
    query = f"""
        SELECT transaction_id, transaction_timestamp, symbol, price, quantity, validation_flag, trader_id, side_id, adv30
        FROM MOIG_TESTING
        WHERE (validation_flag IS NULL OR TRIM(validation_flag) = '')
        AND ((DATE_PART('EPOCH_SECOND', TO_TIMESTAMP_NTZ(LEFT(transaction_timestamp, 19))) * 1000000) +
                CAST(RIGHT(transaction_timestamp, 6) AS BIGINT))
            BETWEEN {current_chunk_start} AND {current_chunk_end}
        ORDER BY symbol, transaction_timestamp;
    """

    chunk_df = spark.read.format("snowflake") \
        .options(**snowflake_options) \
        .option("query", query) \
        .load()
    
    # Row count of the chunk as read, i.e. BEFORE the ADV30 filter below; this count
    # drives both the empty-chunk check and the shuffle-partition estimate.
    chunk_df_count = chunk_df.count()
    

    # Keep only rows whose quantity exceeds adv30 * PERCENTAGE_ADV30 / 10**10.
    # NOTE(review): the 10**10 divisor is not explained anywhere in this notebook - confirm it is intended.
    chunk_df = chunk_df.filter(col("quantity") > (col("adv30") * (PERCENTAGE_ADV30/10000000000)))
    
    # Nothing was read for this time range: jump ahead by CHUNK_SIZE + TOTAL_WINDOW_NS and try again.
    if chunk_df_count <= 0:
        current_chunk_start += CHUNK_SIZE + TOTAL_WINDOW_NS
        continue
    total_time_to_check_empty = time.time() - time_to_check_empty

    # Size the shuffle from a rough estimate (120 bytes per row, 128 MB per partition),
    # but never go below the cluster's default parallelism.
    estimated_size_mb = (chunk_df_count * 120) / (1024 ** 2)
    shuffle_partition_count = math.ceil(estimated_size_mb / 128)

    spark_cores = spark.sparkContext.defaultParallelism
    # b.max is Python's builtin max; the bare name `max` is the PySpark aggregate here.
    shuffle_partitions = b.max(shuffle_partition_count, spark_cores)

    spark.conf.set("spark.sql.shuffle.partitions", shuffle_partitions)

    # Epoch column on the same scale as the SQL chunk filter:
    # <epoch seconds of the first 19 chars> * 1_000_000 + <last 6 chars>.
    # NOTE(review): UNIX_TIMESTAMP parses in the Spark session time zone whereas the
    # Snowflake filter uses TO_TIMESTAMP_NTZ - confirm the session time zone is UTC.
    chunk_df = chunk_df.withColumn("trigger_trx_ts_epoch",
        (expr("CAST(UNIX_TIMESTAMP(LEFT(transaction_timestamp, 19)) AS BIGINT) * 1000000") +
        expr("CAST(RIGHT(transaction_timestamp, 6) AS BIGINT)"))
    )

    if next_window:
        # Rows up to and including the previous chunk's end are look-back context only:
        # they were already evaluated as triggers in the previous iteration. The chunk
        # query uses an inclusive BETWEEN, and current_chunk_start was set to
        # (previous end - TOTAL_WINDOW_NS), so skip_overlap_start equals the previous
        # chunk's inclusive upper bound. The comparison must therefore be `<=`;
        # with `<` a row exactly on the boundary was alerted in both chunks.
        skip_overlap_start = current_chunk_start + TOTAL_WINDOW_NS
        chunk_df = chunk_df.withColumn("is_overlap",
            when(col("trigger_trx_ts_epoch") <= lit(skip_overlap_start), lit(True)).otherwise(lit(False))
        )
    else:
        # First processed chunk: there is no earlier chunk, so no row is overlap.
        next_window = True
        chunk_df = chunk_df.withColumn("is_overlap", lit(False))

    chunk_df = chunk_df.withColumnRenamed( "transaction_id", "trigger_trx_id") \
                .withColumnRenamed("symbol", "trigger_symbol") \
                .withColumnRenamed( "transaction_timestamp", "trigger_trx_ts")

    # Split quantity into buy / sell volume by side_id (1 -> buy, 2 -> sell, otherwise 0).
    chunk_df = chunk_df.withColumn(
        "buy_volume", when(col("side_id") == 1, col("quantity")).otherwise(0).cast("int")
    ).withColumn(
        "sell_volume", when(col("side_id") == 2, col("quantity")).otherwise(0).cast("int")
    )

    chunk_df.persist(serialized_storage)
    
    # The three phase windows share partitioning (per symbol) and ordering (epoch);
    # they differ only in how far back from the current row they reach:
    #   ignition       [-(IGNITION + FOLLOW + REVERSAL), -(FOLLOW + REVERSAL)]
    #   follow-through [-(FOLLOW + REVERSAL),            -REVERSAL]
    #   reversal       [-REVERSAL,                        0]
    symbol_timeline = Window.partitionBy("trigger_symbol") \
                        .orderBy(col("trigger_trx_ts_epoch").cast("long"))

    ignition_window = symbol_timeline.rangeBetween(
        -IGNITION_WINDOW_NS - FOLLOW_THROUGH_NS - REVERSAL_WINDOW_NS,
        -FOLLOW_THROUGH_NS - REVERSAL_WINDOW_NS
    )
    follow_through_window = symbol_timeline.rangeBetween(
        -FOLLOW_THROUGH_NS - REVERSAL_WINDOW_NS,
        - REVERSAL_WINDOW_NS
    )
    reversal_window = symbol_timeline.rangeBetween(-REVERSAL_WINDOW_NS, 0)

    # Named pieces reused by several ignition metrics below.
    iw_first_price = first("price").over(ignition_window)
    iw_latest_price = last("price").over(ignition_window)

    # Insertion order matters: the later entries refer by name to columns produced
    # by the earlier ones inside the same select.
    ignition_expressions = {
        "iw_avg_price": mean("price").over(ignition_window),
        "iw_std_price": stddev("price").over(ignition_window),
        "iw_avg_volume": mean("quantity").over(ignition_window),
        "iw_std_volume": stddev("quantity").over(ignition_window),
        "iw_momentum": (iw_latest_price - iw_first_price) / iw_first_price,
        "iw_last_price": iw_latest_price,
        "iw_last_volume": last("quantity").over(ignition_window),
        "iw_max_price": max("price").over(ignition_window),
        "iw_f_trx": first("trigger_trx_ts").over(ignition_window),
        "iw_l_trx": last("trigger_trx_ts").over(ignition_window),
        "iw_price_th": col("iw_avg_price") + THRES_PRICE_MOIG * col("iw_std_price"),
        "iw_volume_th": col("iw_avg_volume") + THRES_QUANTITY_MOIG * col("iw_std_volume"),
        "iw_total_buy_volume": sum(col("buy_volume").cast("int")).over(ignition_window),
        "iw_total_sell_volume": sum(col("sell_volume").cast("int")).over(ignition_window),
        "buy_sell_ratio": (col("iw_total_buy_volume") / col("iw_total_sell_volume")).cast("float")
    }

    # Overlap rows keep their original columns but get NULL for every ignition metric.
    ignition_columns = [
        when(col("is_overlap") == False, column_expr).alias(column_name)
        for column_name, column_expr in ignition_expressions.items()
    ]
    stats_df = chunk_df.select("*", *ignition_columns)

    stats_df.persist(serialized_storage)
    chunk_df.unpersist(blocking=True)

    # A non-overlap row is an ignition trigger when, over its ignition window:
    #   - the last price and the last volume exceed mean + k * stddev,
    #   - the relative price change (iw_momentum, a fraction) exceeds the momentum threshold,
    #   - buying dominates: buy/sell ratio above the threshold, or buys with no sells at all.
    ignition_condition = (
        (col("iw_last_price") > col("iw_price_th")) &
        (col("iw_last_volume") > col("iw_volume_th")) &
        (col("iw_momentum") > THRES_MOMENTUM_MOIG) &
        (
            ((col("iw_total_buy_volume") > 0) & (col("iw_total_sell_volume") > 0) & (col("buy_sell_ratio") > THRES_IGNITION_BUY_SELL_RATIO)) |
            ((col("iw_total_buy_volume") > 0) & (col("iw_total_sell_volume") == 0))
        ) &
        (~col("is_overlap"))
    )

    # Breach columns are only populated on trigger rows; each *_breach_pct is the relative
    # excess over its threshold in percent.
    # iw_momentum and THRES_MOMENTUM_MOIG are both fractions, so they are compared
    # directly. (The threshold used to be multiplied by 100 here while the momentum
    # was not, which mixed percent with fraction and produced meaningless values.)
    stats_df = stats_df.withColumn("phase", when(ignition_condition, lit("Ignition"))) \
        .withColumn("price_breach_pct", when(ignition_condition, ((col("iw_last_price") - col("iw_price_th")) / col("iw_price_th")) * 100)) \
        .withColumn("volume_breach_pct", when(ignition_condition, ((col("iw_last_volume") - col("iw_volume_th")) / col("iw_volume_th")) * 100)) \
        .withColumn("momentum_breach_pct", when(ignition_condition, abs((col("iw_momentum") - lit(THRES_MOMENTUM_MOIG)) / lit(THRES_MOMENTUM_MOIG) * 100))) \
        .withColumn("buy_sell_ratio_breach_pct", when(ignition_condition, col("buy_sell_ratio"))) \
        .select(
        "trigger_symbol", "trigger_trx_ts", "trigger_trx_id", "iw_avg_price", "iw_f_trx", "iw_l_trx", "trader_id","price",
        "iw_max_price", "trigger_trx_ts_epoch","price_breach_pct", "volume_breach_pct", "momentum_breach_pct", "buy_sell_ratio_breach_pct", "phase", "buy_volume","sell_volume"
    )
    
    
    # print("Count after aggregations: ",stats_df.count())

    # Keep only symbols that produced at least one ignition trigger in this chunk.
    ignition_symbols_df = stats_df.filter(col("phase") == "Ignition") \
                              .select("trigger_symbol") \
                              .distinct()
    
    stats_df = stats_df.join(ignition_symbols_df, on="trigger_symbol", how="inner")
    # print("After the filter of distinct Symbols: ",stats_df.count())
    
    # On trigger rows, collect the transaction ids and the distinct trader ids
    # that fall inside the ignition window.
    stats_df = stats_df.withColumn("iw_trx_range", when(col("phase") == "Ignition", collect_list("trigger_trx_id").over(ignition_window))
    ).withColumn("iw_trader_ids", when(col("phase") == "Ignition", collect_set("trader_id").over(ignition_window)))
    
    
   

    # One row per trigger: [start_epoch, ignition_epoch] is the follow-through +
    # reversal look-back span that ends at the trigger row.
    ignition_df = stats_df.filter(col("phase") == "Ignition") \
        .select("trigger_symbol", "trigger_trx_ts_epoch")

    ignition_df = ignition_df.withColumn("start_epoch", col("trigger_trx_ts_epoch") - (FOLLOW_THROUGH_NS + REVERSAL_WINDOW_NS)) \
                            .withColumnRenamed("trigger_trx_ts_epoch", "ignition_epoch")
    
    # Range join: keep only rows of the same symbol that lie inside some trigger's span.
    joined_df = stats_df.alias("s").join(
        ignition_df.alias("i"),
        (col("s.trigger_symbol") == col("i.trigger_symbol")) &
        (col("s.trigger_trx_ts_epoch") >= col("i.start_epoch")) &
        (col("s.trigger_trx_ts_epoch") <= col("i.ignition_epoch"))
    ).select("s.*")  # Select only original stats_df columns
   
    # print(joined_df.count())
    # A row that lies inside several triggers' spans is matched once per trigger; collapse the copies.
    joined_df = joined_df.dropDuplicates()
    joined_df.persist(serialized_storage)
    # NOTE(review): stats_df was re-assigned several times since it was persisted, so this
    # call presumably targets a derived plan and may not release that earlier cache - verify.
    stats_df.unpersist(blocking=True)
    stats_df = joined_df 
    
    # Follow-through metrics, evaluated only on ignition trigger rows.
    # Insertion order matters: the last three entries refer by name to columns
    # that already exist (iw_trader_ids) or are produced earlier in the same select.
    follow_through_expressions = {
        "ftw_max_price": max("price").over(follow_through_window),
        "fw_trx_range" : collect_list("trigger_trx_id").over(follow_through_window),
        "fw_trader_ids" : collect_set("trader_id").over(follow_through_window),
        "trader_id_match_count" : size(array_intersect(col("fw_trader_ids"), col("iw_trader_ids"))),
        "iw_trader_id_count" : size(col("iw_trader_ids")),
        "fw_trader_id_match_percentage" : col("trader_id_match_count") / col("iw_trader_id_count")
    }

    follow_through_columns = [
        when(col("phase") == "Ignition", column_expr).alias(column_name)
        for column_name, column_expr in follow_through_expressions.items()
    ]
    stats_df = stats_df.select("*", *follow_through_columns)

    # A trigger has follow-through when the follow-through window made a new high over
    # the ignition window and the share of ignition traders also seen in the
    # follow-through window is at or below the configured threshold.
    follow_through_condition = (
        (col("phase") == "Ignition") &
        (col("ftw_max_price") > col("iw_max_price")) &
        (col("fw_trader_id_match_percentage") <= lit(THRES_TRADER_ID_MATCH_PERCENTAGE_FOLLOW))
    )

    # Columns carried forward into the reversal phase.
    columns_after_follow_through = [
        "trigger_symbol",
        "trigger_trx_ts",
        "trigger_trx_id",
        "iw_avg_price",
        "trader_id",
        "price",
        "trigger_trx_ts_epoch",
        "price_breach_pct",
        "volume_breach_pct",
        "momentum_breach_pct",
        "buy_sell_ratio_breach_pct",
        "phase",
        "buy_volume",
        "sell_volume",
        "iw_trx_range",
        "iw_trader_ids",
        "fw_trx_range",
        "iw_trader_id_count",
        "is_follow_through",
        "trader_id_mismatch_pct_follow",
    ]

    # Flag the follow-through rows and record their trader match share in percent.
    stats_df = (
        stats_df
        .withColumn("is_follow_through", when(follow_through_condition, lit("Follow-Through")))
        .withColumn(
            "trader_id_mismatch_pct_follow",
            when(follow_through_condition, col("fw_trader_id_match_percentage") * 100)
        )
        .select(*columns_after_follow_through)
    )
    
    # print(stats_df.count())

    # One row per trigger that also passed the follow-through test.
    follow_through_df = stats_df.filter((col("phase") == "Ignition") & (col("is_follow_through") == "Follow-Through")) \
        .select("trigger_symbol", "trigger_trx_ts_epoch")

    # [start_epoch, ignition_epoch] is the reversal look-back span ending at the trigger row.
    follow_through_df = follow_through_df.withColumn("start_epoch", col("trigger_trx_ts_epoch") - (REVERSAL_WINDOW_NS)) \
                            .withColumnRenamed("trigger_trx_ts_epoch", "ignition_epoch")

    # Range join: keep only rows of the same symbol that lie inside some surviving trigger's span.
    joined_df = stats_df.alias("s").join(
        follow_through_df.alias("i"),
        (col("s.trigger_symbol") == col("i.trigger_symbol")) &
        (col("s.trigger_trx_ts_epoch") >= col("i.start_epoch")) &
        (col("s.trigger_trx_ts_epoch") <= col("i.ignition_epoch"))
    ).select("s.*")

    # A row inside several triggers' spans is matched once per trigger; collapse the copies.
    joined_df = joined_df.dropDuplicates()
    joined_df.persist(serialized_storage)
    # NOTE(review): stats_df here is a plan derived from the previously persisted frame,
    # so this call presumably does not release that cache - verify.
    stats_df.unpersist(blocking=True)
    stats_df = joined_df 
    
    # Reversal metrics, evaluated only on triggers that passed the follow-through test.
    # Insertion order matters: later entries refer by name to columns produced by earlier ones.
    # NOTE(review): "iw_trader_id_count" is already a column of stats_df at this point,
    # so the select below yields two columns with that name - confirm this is intended.
    reversal_expressions = {
        "rw_min_price": min("price").over(reversal_window),
        "rw_trx_range" : collect_list("trigger_trx_id").over(reversal_window),
        "rw_trader_ids" : collect_set("trader_id").over(reversal_window),
        "trader_id_match_count" : size(array_intersect(col("rw_trader_ids"), col("iw_trader_ids"))),
        "iw_trader_id_count" : size(col("iw_trader_ids")),
        "rw_trader_id_match_percentage" : col("trader_id_match_count") / col("iw_trader_id_count"),
        "rw_total_buy_volume" : sum(col("buy_volume")).over(reversal_window),
        "rw_total_sell_volume" : sum(col("sell_volume")).over(reversal_window),
        "sell_buy_ratio" : (col("rw_total_sell_volume") / col("rw_total_buy_volume"))
    }

    stats_df = stats_df.select(
        "*",  # Keep existing columns
        *[
            when((col("phase") == "Ignition") & (col("is_follow_through") == "Follow-Through"), expr).alias(col_name)
            for col_name, expr in reversal_expressions.items()
        ]
    )

    # A follow-through trigger is a full momentum-ignition event when, over its reversal window:
    #   - the minimum price fell below the ignition-window average price,
    #   - the share of ignition traders also seen in the reversal window reaches the threshold,
    #   - selling dominates: sell/buy ratio above the threshold, or sells with no buys at all.
    reversal_condition = (
        (col("phase") == "Ignition") &
        (col("is_follow_through") == "Follow-Through") &
        (col("rw_min_price") < col("iw_avg_price")) &
        (col("rw_trader_id_match_percentage") >= lit(THRES_TRADER_ID_MATCH_PERCENTAGE_REVERSAL)) &  # share of ignition traders seen again in the reversal window
        (
            ((col("rw_total_buy_volume") > 0) & (col("rw_total_sell_volume") > 0) & (col("sell_buy_ratio") > REVERSAL_THRESHOLD_SELL_BUY_RATIO)) |
            ((col("rw_total_sell_volume") > 0) & (col("rw_total_buy_volume") == 0))
        )
    )
    
    # Final events: all_trx_range is the ignition, follow-through and reversal
    # transaction-id lists concatenated in that order.
    momentum_ignition_df = stats_df.filter(reversal_condition).withColumn(
                            "all_trx_range",
                            concat(col("iw_trx_range"), col("fw_trx_range"), col("rw_trx_range"))
                        ).select(
                            "trigger_symbol","trigger_trx_ts","trigger_trx_id","trigger_trx_ts_epoch","price_breach_pct","volume_breach_pct","momentum_breach_pct","buy_sell_ratio_breach_pct","all_trx_range","trader_id_mismatch_pct_follow",
                            (col("rw_trader_id_match_percentage") * 100).alias("trader_id_match_pct_reversal"),
                            col("sell_buy_ratio").alias("sell_buy_ratio_breach_pct"),
                            lit("Momentum-Ignition").alias("isMomentum")
                        )

    # count() is the first action on momentum_ignition_df after persist().
    momentum_ignition_df.persist(serialized_storage)
    momentum_ignition_total_count += momentum_ignition_df.count()
    
    # Per-event breach values, keyed by the last transaction id of the event's
    # range (used below to join them to the generated alert ids).
    breach_df = momentum_ignition_df.select(
        col("trigger_trx_ts"),
        element_at(col("all_trx_range"), -1).alias("triggering_trx"),
        col("price_breach_pct"),
        col("volume_breach_pct"),
        col("momentum_breach_pct"),
        col("buy_sell_ratio_breach_pct"),
        col("trader_id_mismatch_pct_follow"),
        col("trader_id_match_pct_reversal"),
        col("sell_buy_ratio_breach_pct")
    )

    # No partitionBy: row_number() over this spec numbers all events in one global sequence.
    window_spec = Window.orderBy("trigger_trx_ts")

    # Highest existing MOIG_<n> alert id, ordered by its numeric suffix.
    # NOTE(review): if an id with a non-numeric suffix sorts first, the int() below raises - verify the data.
    alert_query = """
        (
            SELECT ALERT_ID
            FROM THRESHOLD_BREACH
            WHERE ALERT_ID LIKE 'MOIG_%'
            ORDER BY TRY_TO_NUMBER(SUBSTRING(ALERT_ID, 6)) DESC
            LIMIT 1
        ) AS last_alert
    """

    last_alert_id_df = spark.read \
        .format("snowflake") \
        .options(**snowflake_options_alert) \
        .option("query", alert_query) \
        .load()
    
    # Continue numbering after the last stored alert; start from 0 when none exists.
    last_id_row = last_alert_id_df.collect()
    if last_id_row:
        last_alert_id_str = last_id_row[0]["ALERT_ID"]
        last_num = int(last_alert_id_str.replace("MOIG_", ""))
    else:
        last_num = 0
    
    offset = last_num

    # Generate alert metadata: a sequential MOIG_<n> id continuing after `offset`, the
    # triggering transaction (last id in all_trx_range) and a JSON array describing the
    # three phase intervals, each computed backwards from the trigger timestamp.
    # NOTE(review): asset_analytics_id is written as the literal 1 rather than the
    # asset_analytics_id value looked up from ASSETANALYTICS - confirm this is intended.
    alert_df = momentum_ignition_df \
        .withColumn("incremental_id", row_number().over(window_spec)) \
        .withColumn("generated_id", col("incremental_id") + offset) \
        .withColumn("alert_id", concat(lit("MOIG_"), col("generated_id"))) \
        .withColumn("triggering_transaction_internal_id", element_at(col("all_trx_range"), -1)) \
        .withColumn("trigger_trx_ts_ts", col("trigger_trx_ts").cast("timestamp")) \
        .withColumn(
        "alert_phases",
        to_json(
            array(
                struct(
                    lit("Ignition Phase").alias("phase_name"),
                    (col("trigger_trx_ts_ts") - expr(f"INTERVAL {IGNITION_WINDOW + FOLLOW_THROUGH + REVERSAL_WINDOW} MINUTES")).alias("start_date"),
                    (col("trigger_trx_ts_ts") - expr(f"INTERVAL {FOLLOW_THROUGH + REVERSAL_WINDOW} MINUTES")).alias("end_date")
                ),
                struct(
                    lit("Follow-Through Phase").alias("phase_name"),
                    (col("trigger_trx_ts_ts") - expr(f"INTERVAL {FOLLOW_THROUGH + REVERSAL_WINDOW} MINUTES")).alias("start_date"),
                    (col("trigger_trx_ts_ts") - expr(f"INTERVAL {REVERSAL_WINDOW} MINUTES")).alias("end_date")
                ),
                struct(
                    lit("Reversal Phase").alias("phase_name"),
                    (col("trigger_trx_ts_ts") - expr(f"INTERVAL {REVERSAL_WINDOW} MINUTES")).alias("start_date"),
                    col("trigger_trx_ts_ts").alias("end_date")
                )
            )
        )
    ) \
        .filter(col("triggering_transaction_internal_id").isNotNull()) \
        .select(
            "alert_id", "trigger_trx_ts",
            lit(1).alias("asset_analytics_id"),
            lit(0.9).alias("alert_score"),
            lit("High").alias("alert_severity"),
            lit(1).alias("alert_status_id"),
            "triggering_transaction_internal_id",
            current_timestamp().alias("creation_timestamp"),
            current_timestamp().alias("last_update_timestamp"),
            "all_trx_range",
            "alert_phases"
    )
    
    # Attach each event's breach values to its generated alert id
    # (matched on the triggering transaction id).
    joined_df = breach_df.alias("b").join(
        alert_df.alias("a"),
        col("b.triggering_trx") == col("a.triggering_transaction_internal_id"),
        "inner"
    )

    # Map metric names to threshold IDs. Every key must be unique: FOLLOW_THROUGH used to
    # share the 'reversal_window' key with REVERSAL_WINDOW, so the dict silently kept only
    # the latter and the follow-through threshold was never written.
    column_to_threshold = {
        'price_breach_pct': threshold_id_map['MOIG_THRESHOLD_PRICE'],
        'volume_breach_pct': threshold_id_map['MOIG_THRESHOLD_VOLUME'],
        'momentum_breach_pct': threshold_id_map['MOIG_THRESHOLD_MOMENTUM'],
        'buy_sell_ratio_breach_pct': threshold_id_map['THRES_IGNITION_BUY_SELL_RATIO'],
        'trader_id_mismatch_pct_follow': threshold_id_map['THRES_TRADER_ID_MATCH_PERCENTAGE_FOLLOW'],
        'trader_id_match_pct_reversal': threshold_id_map['THRES_TRADER_ID_MATCH_PERCENTAGE_REVERSAL'],
        'sell_buy_ratio_breach_pct': threshold_id_map['REVERSAL_THRESHOLD_SELL_BUY_RATIO'],
        'ignition_window': threshold_id_map['IGNITION_WINDOW'],
        'follow_through_window': threshold_id_map['FOLLOW_THROUGH'],
        'reversal_window': threshold_id_map['REVERSAL_WINDOW'],
        'ADV30_breach' : threshold_id_map['PERCENTAGE_ADV30']
    }

    threshold_mapping = [
        (metric, threshold_id) for metric, threshold_id in column_to_threshold.items()
    ]
    threshold_df = spark.createDataFrame(threshold_mapping, ["metric", "threshold_id"])


    # One struct per metric; the metric names must match the keys of column_to_threshold.
    # The window and ADV30 thresholds carry no breach value (NULL) but still get a row per alert.
    alert_breach_df = joined_df.select(
        col("a.alert_id"),
        array(
            struct(lit("price_breach_pct").alias("metric"), col("b.price_breach_pct").alias("value")),
            struct(lit("volume_breach_pct").alias("metric"), col("b.volume_breach_pct").alias("value")),
            struct(lit("momentum_breach_pct").alias("metric"), col("b.momentum_breach_pct").alias("value")),
            struct(lit("buy_sell_ratio_breach_pct").alias("metric"), col("b.buy_sell_ratio_breach_pct").alias("value")),
            struct(lit("trader_id_mismatch_pct_follow").alias("metric"), col("b.trader_id_mismatch_pct_follow").alias("value")),
            struct(lit("trader_id_match_pct_reversal").alias("metric"), col("b.trader_id_match_pct_reversal").alias("value")),
            struct(lit("sell_buy_ratio_breach_pct").alias("metric"), col("b.sell_buy_ratio_breach_pct").alias("value")),
            struct(lit("ignition_window").alias("metric"), lit(None).cast("string").alias("value")),
            struct(lit("follow_through_window").alias("metric"), lit(None).cast("string").alias("value")),
            struct(lit("reversal_window").alias("metric"), lit(None).cast("string").alias("value")),
            struct(lit("ADV30_breach").alias("metric"), lit(None).cast("string").alias("value"))
        ).alias("breach_metrics")
    )

    # Unpivot: one (alert_id, metric, breach_value) row per metric.
    exploded_df = alert_breach_df.select(
        col("alert_id"),
        explode(col("breach_metrics")).alias("metric_struct")
    ).select(
        col("alert_id"),
        col("metric_struct.metric").alias("metric"),
        col("metric_struct.value").alias("breach_value")
    )

    # Resolve each metric name to its threshold id; this is the shape written to THRESHOLD_BREACH.
    final_df = exploded_df.join(
        threshold_df,
        exploded_df.metric == threshold_df.metric,
        "left"
    ).select(
        col("alert_id"),
        col("threshold_id"),
        col("breach_value")
    )

    # Create alerted_df: one row per transaction in the alert's range, excluding the
    # triggering transaction. The triggering transaction is the last element of
    # all_trx_range and is already stored on the alert itself.
    alerted_df = alert_df.withColumn(
        "remaining_transactions",
        # slice(array, start, length) is 1-based; a length of size - 1 drops the last
        # element. (The length used to be size, which kept the triggering transaction.)
        slice(col("all_trx_range"), 1, size(col("all_trx_range")) - 1)
    ).withColumn("transaction_internal_id", explode(col("remaining_transactions"))) \
    .select("alert_id", "transaction_internal_id", "creation_timestamp", "last_update_timestamp")

    # Remove all_trx_range from alert_df as it's no longer needed
    alert_df = alert_df.drop("all_trx_range")

    # Final column set and order for the ALERT table write.
    alert_df = alert_df.select(
        "alert_id",
        "asset_analytics_id",
        "alert_score",
        "alert_severity",
        "alert_status_id",
        "triggering_transaction_internal_id",
        "creation_timestamp",
        "last_update_timestamp",
        "alert_phases"
    )

    alert_df.persist(serialized_storage)
    alert_df.display()
    alert_df.write.format("snowflake") \
    .options(**snowflake_options_alert) \
    .option("dbtable", "ALERT") \
    .mode("append") \
    .save()

    alerted_df.write.format("snowflake") \
    .options(**snowflake_options_alert) \
    .option("dbtable", "ALERTED_TRANSACTION") \
    .mode("append") \
    .save()

    final_df.write.format("snowflake") \
    .options(**snowflake_options_alert) \
    .option("dbtable", "THRESHOLD_BREACH") \
    .mode("append") \
    .save()

    chunk_df.unpersist()
    stats_df.unpersist()
    # ignition_df.unpersist()
    # follow_through_df.unpersist()
    # reversal_df.unpersist()
    alert_df.unpersist()

    print(f"Time Frame in consideration: {datetime.fromtimestamp(int(current_chunk_start) // 10 ** 6).strftime('%H:%M')} to {datetime.fromtimestamp(int(current_chunk_end) // 10 ** 6).strftime('%H:%M')}")

    current_chunk_start = current_chunk_end - TOTAL_WINDOW_NS

print(f"Total Momentum Ignition Events: {momentum_ignition_total_count}")

alert_id,asset_analytics_id,alert_score,alert_severity,alert_status_id,triggering_transaction_internal_id,creation_timestamp,last_update_timestamp,alert_phases


Time Frame in consideration: 09:30 to 10:30
Total Momentum Ignition Events: 0


In [0]:
# Total wall-clock time since start_time was taken in the setup cell near the top of the notebook.
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution Time MOIG: {execution_time:.6f} seconds")

Execution Time MOIG: 38.003679 seconds
