In [0]:
# Notebook 1: Create Worklist with Data - CDF Incremental with Trust Filter
# Creates a worklist table with actual blob data for efficient processing

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from datetime import datetime
import time

# Configuration
SOURCE_TABLE = "4_prod.raw.mill_ce_blob"
TARGET_TABLE = "4_prod.bronze.mill_blob_text"
STAGING_DB = "4_prod.tmp"
# Capture the run start once so RUN_ID and RUN_TS always describe the same instant
# (two separate datetime.now() calls can straddle a second boundary).
RUN_TS = datetime.now()
RUN_ID = RUN_TS.strftime("%Y%m%d_%H%M%S")
WORKLIST_TABLE = f"{STAGING_DB}.blob_worklist_{RUN_ID}"
MAX_BLOB_SIZE = 16 * 1024 * 1024  # 16 MB; events whose summed BLOB_LENGTH exceeds this are marked "oversized"
USE_CDF = True  # Try Change Data Feed first; falls back to a time-window scan if CDF fails

print("="*80)
print(f"CREATING WORKLIST: {WORKLIST_TABLE}")
print(f"Run ID: {RUN_ID}")
print("="*80)

#=========================
# Helper Functions
#=========================

def get_last_processed_timestamp(target_table: str, safety_margin_minutes: int = 10):
    """Return the incremental watermark for *target_table*.

    The watermark is ``MAX(ADC_UPDT)`` of the target minus a safety margin
    (10 minutes by default), so that late-arriving source updates are re-read.
    Rows that were already processed are removed later by the
    (EVENT_ID, ADC_UPDT) anti-join.

    The subtraction is done inside Spark SQL. This avoids round-tripping the
    value through a Python ``datetime`` string: ``collect()`` converts it to the
    driver's local time, whereas ``timestamp('...')`` re-parses it in the Spark
    session time zone.

    Args:
        target_table: Fully qualified name of the table holding processed rows.
        safety_margin_minutes: How far before the latest ADC_UPDT to restart.

    Returns:
        The watermark timestamp as collected from Spark. Returns
        ``datetime(1970, 1, 1)`` when the target has no ADC_UPDT values or
        cannot be read. This is best-effort: a read failure leads to a full
        rescan rather than a crash.
    """
    epoch = datetime(1970, 1, 1)
    try:
        row = spark.sql(f"""
            SELECT CAST(MAX(ADC_UPDT) AS TIMESTAMP)
                   - INTERVAL {int(safety_margin_minutes)} MINUTES AS safe_ts
            FROM {target_table}
        """).collect()[0]
    except Exception as e:
        print(f"  Warning: Could not get last timestamp from target: {str(e)}")
        return epoch

    # MAX over an empty table is NULL -> start from the epoch
    safe_ts = row['safe_ts']
    return safe_ts if safe_ts is not None else epoch

def get_changes_via_cdf(source_table: str, start_ts, trust_filter: str = "Barts"):
    """
    Try to get changed EVENT_IDs using Change Data Feed (CDF).

    Reads ``table_changes`` from the first version whose commit timestamp is
    >= *start_ts* up to the latest version. It keeps inserts and update
    post-images for the given Trust. The version range is split at commits
    that look like schema/property changes, so that no single
    ``table_changes`` call spans one of them.

    Args:
        source_table: Fully qualified Delta table name.
        start_ts: Lower bound compared with the history ``timestamp`` column.
        trust_filter: Value the ``Trust`` column must equal.

    Returns:
        DataFrame of distinct (EVENT_ID, ADC_UPDT) pairs. Returns None if CDF
        cannot be used; the caller then falls back to a time-window scan.
    """
    try:
        print(f"  Attempting CDF query from {start_ts}...")

        # Reuse one DESCRIBE HISTORY result for every version lookup below
        # instead of issuing the command three times.
        history = spark.sql(f"DESCRIBE HISTORY {source_table}")
        end_version = history.selectExpr("MAX(version) AS v").collect()[0]['v']

        if end_version is None:
            print("  No history found, falling back to time-window")
            return None

        # First commit at/after start_ts. If none exists, re-read only the latest
        # version; the later anti-join drops anything already processed.
        start_version = (history
                         .filter(F.col("timestamp") >= F.expr(f"timestamp('{start_ts}')"))
                         .selectExpr("MIN(version) AS v")
                         .collect()[0]['v'])
        if start_version is None:
            start_version = end_version

        print(f"  CDF version range: {start_version} to {end_version}")

        # Commits that may change schema/table properties; each one starts a new segment
        schema_change_predicate = """
            operation IN ('SET TBLPROPERTIES','REPLACE','REPLACE TABLE',
                          'ADD COLUMNS','CHANGE COLUMN','ALTER TABLE ADD COLUMNS',
                          'ALTER TABLE CHANGE COLUMN','CONVERT TO DELTA','RESTORE')
             OR instr(upper(CAST(operationParameters AS STRING)), 'COLUMNMAPPING') > 0
             OR instr(upper(CAST(operationParameters AS STRING)), 'SCHEMA') > 0
        """
        boundaries = sorted({
            r['version']
            for r in (history
                      .filter(F.col("version").between(start_version, end_version))
                      .filter(F.expr(schema_change_predicate))
                      .select("version")
                      .collect())
        })

        # Segments are [edge_i, edge_{i+1} - 1]. Empty ranges (e.g. a boundary equal
        # to start_version) are skipped.
        edges = [start_version] + boundaries + [end_version + 1]
        segments = [
            f"""
                SELECT EVENT_ID, ADC_UPDT
                FROM table_changes('{source_table}', {s_ver}, {next_edge - 1})
                WHERE _change_type IN ('insert', 'update_postimage')
                  AND Trust = '{trust_filter}'
            """
            for s_ver, next_edge in zip(edges, edges[1:])
            if s_ver <= next_edge - 1
        ]

        if not segments:
            print("  No CDF segments to read, falling back to time-window")
            return None

        # A single DISTINCT yields the unique (EVENT_ID, ADC_UPDT) pairs directly
        union_sql = " UNION ALL ".join(segments)
        changes_df = spark.sql(f"SELECT DISTINCT EVENT_ID, ADC_UPDT FROM ({union_sql}) ch")

        # count() forces the table_changes reads here, inside the try, so CDF
        # errors surface now and trigger the fallback.
        change_count = changes_df.count()
        print(f"  CDF found {change_count:,} changed EVENT_IDs")

        return changes_df

    except Exception as e:
        print(f"  CDF query failed: {str(e)[:500]}")
        print("  Falling back to time-window scan")
        return None

def get_changes_via_timewindow(source_table: str, start_ts, trust_filter: str = "Barts"):
    """
    Fallback change detection when CDF is disabled or unusable.

    Selects the distinct (EVENT_ID, ADC_UPDT) pairs from *source_table* whose
    ADC_UPDT is at or after *start_ts* and whose Trust equals *trust_filter*.
    Prints how many pairs were found and returns them as a DataFrame.
    """
    print(f"  Using time-window scan from {start_ts}...")
    
    window_sql = f"""
        SELECT DISTINCT EVENT_ID, ADC_UPDT
        FROM {source_table}
        WHERE ADC_UPDT >= timestamp('{start_ts}')
          AND Trust = '{trust_filter}'
    """
    pairs = spark.sql(window_sql)

    print(f"  Time-window scan found {pairs.count():,} changed EVENT_IDs")
    return pairs

#=========================
# Main Processing Logic
#=========================

# Step 1: Get last processed timestamp
print(f"[{datetime.now().strftime('%H:%M:%S')}] Getting last processed timestamp...")
last_processed_ts = get_last_processed_timestamp(TARGET_TABLE)
print(f"Last processed timestamp: {last_processed_ts}")

# Step 2: Get changed EVENT_IDs (CDF with fallback)
print(f"[{datetime.now().strftime('%H:%M:%S')}] Detecting changes...")

if USE_CDF:
    changed_events = get_changes_via_cdf(SOURCE_TABLE, last_processed_ts, trust_filter="Barts")
    if changed_events is None:
        changed_events = get_changes_via_timewindow(SOURCE_TABLE, last_processed_ts, trust_filter="Barts")
else:
    print("  CDF disabled, using time-window scan")
    changed_events = get_changes_via_timewindow(SOURCE_TABLE, last_processed_ts, trust_filter="Barts")

# Step 3: Filter out events already processed with same ADC_UPDT
print(f"[{datetime.now().strftime('%H:%M:%S')}] Filtering already processed events...")

# Cap the batch size, taking the OLDEST changes first. The next run's watermark
# is MAX(ADC_UPDT) in the target minus a margin. If an unordered limit let newer
# events through and they reach the target, the watermark would move past older
# changes that were never selected, and they would be skipped. A total order
# (ADC_UPDT, EVENT_ID) also makes the limit deterministic. The count below and
# the join in Step 4 both re-evaluate this plan, so they now see the same events.
MAX_EVENTS_PER_RUN = 250000
new_events = (changed_events
    .join(
        spark.table(TARGET_TABLE).select("EVENT_ID", "ADC_UPDT").distinct(),
        on=["EVENT_ID", "ADC_UPDT"],
        how="left_anti"
    )
    .orderBy("ADC_UPDT", "EVENT_ID")
    .limit(MAX_EVENTS_PER_RUN))

new_event_count = new_events.count()
print(f"Found {new_event_count:,} new Barts events to process")

if new_event_count == 0:
    print("No new events to process!")
    dbutils.notebook.exit("NO_WORK")

# Step 4: Build worklist with blob data
print(f"[{datetime.now().strftime('%H:%M:%S')}] Building worklist with blob data...")

# Columns carried from the source: EVENT_ID, then the per-chunk fields that are
# packed into the chunks_data structs below (BLOB_SEQ_NUM is the first of these)
META_AND_BLOB_COLS = [
    "EVENT_ID", "BLOB_SEQ_NUM",
    "VALID_UNTIL_DT_TM", "VALID_FROM_DT_TM",
    "UPDT_DT_TM", "UPDT_ID", "UPDT_TASK", "UPDT_CNT", "UPDT_APPLCTX",
    "LAST_UTC_TS", "ADC_UPDT", "COMPRESSION_CD", "BLOB_CONTENTS", "BLOB_LENGTH"
]
CHUNK_FIELDS = [c for c in META_AND_BLOB_COLS if c != "EVENT_ID"]

# Load source rows for the selected (EVENT_ID, ADC_UPDT) pairs (Trust filter re-applied)
source_data = (spark.table(SOURCE_TABLE)
               .filter(F.col("Trust") == "Barts")
               .join(F.broadcast(new_events), on=["EVENT_ID", "ADC_UPDT"], how="inner")
               .select(*META_AND_BLOB_COLS))

# Deduplicate: keep one row per (EVENT_ID, BLOB_SEQ_NUM). Prefer the latest
# VALID_UNTIL_DT_TM, then UPDT_DT_TM, then LAST_UTC_TS.
# NOTE(review): this window ignores ADC_UPDT, while the aggregation below groups
# by (EVENT_ID, ADC_UPDT). If one EVENT_ID appears with several ADC_UPDT values
# in a run, a chunk kept for one ADC_UPDT removes that BLOB_SEQ_NUM from the
# other group, which may then be missing chunks. Confirm this is intended.
w_temporal = Window.partitionBy("EVENT_ID", "BLOB_SEQ_NUM").orderBy(
    F.col("VALID_UNTIL_DT_TM").desc(),
    F.col("UPDT_DT_TM").desc(),
    F.col("LAST_UTC_TS").desc()
)

deduped_data = (source_data
                .withColumn("version_rank", F.row_number().over(w_temporal))
                .filter(F.col("version_rank") == 1)
                .drop("version_rank"))

# Aggregate chunks by EVENT_ID and ADC_UPDT.
# collect_list gives no ordering guarantee after a shuffle, so the chunk array is
# sorted explicitly. Structs compare field by field and BLOB_SEQ_NUM is the first
# field, so chunks_data ends up in ascending BLOB_SEQ_NUM order.
worklist_with_meta = (deduped_data
    .withColumn("chunk_size", F.coalesce(F.col("BLOB_LENGTH").cast("long"), F.lit(0)))
    .groupBy("EVENT_ID", "ADC_UPDT")
    .agg(
        F.sum("chunk_size").alias("compressed_size"),
        F.count("*").alias("chunk_count"),
        F.sort_array(F.collect_list(F.struct(*CHUNK_FIELDS))).alias("chunks_data")
    )
    .withColumn("status",
                F.when(F.col("compressed_size") > MAX_BLOB_SIZE, "oversized")
                .otherwise("pending"))
    .withColumn("batch_num", F.lit(0))
    .withColumn("process_ts", F.lit(None).cast("timestamp"))
    .withColumn("error_msg", F.lit(None).cast("string")))

# Write worklist table
print(f"[{datetime.now().strftime('%H:%M:%S')}] Writing worklist table...")
worklist_with_meta.write.mode("overwrite").saveAsTable(WORKLIST_TABLE)

# Optimize for efficient querying
spark.sql(f"OPTIMIZE {WORKLIST_TABLE} ZORDER BY (ADC_UPDT, EVENT_ID)")

# Per-status event and chunk counts, computed in a single pass over the worklist
stats = (spark.table(WORKLIST_TABLE)
         .groupBy("status")
         .agg(F.count("*").alias("count"), F.sum("chunk_count").alias("chunks"))
         .collect())
stats_dict = {row["status"]: row["count"] for row in stats}
# The worklist can be empty even when new_event_count > 0, e.g. if a changed
# (EVENT_ID, ADC_UPDT) pair no longer matches a source row. Default the total to
# 0 so the ":," format below does not fail on None.
total_chunks = sum(row["chunks"] or 0 for row in stats)

print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Worklist created successfully:")
print(f"  - Pending: {stats_dict.get('pending', 0):,}")
print(f"  - Oversized: {stats_dict.get('oversized', 0):,}")
print(f"  - Total chunks stored: {total_chunks:,}")
print("  - Trust filter: Barts only")
print(f"  - Incremental from: {last_processed_ts}")
print(f"  - CDF enabled: {USE_CDF}")

# Step 5: Handle oversized events
# Oversized events are not decoded. A placeholder row with the chunk metadata,
# a NULL payload and a STATUS message is appended to the target instead.
oversized_count = stats_dict.get('oversized', 0)
if oversized_count > 0:
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Writing {oversized_count} oversized events to target...")

    oversized = (spark.table(WORKLIST_TABLE)
                 .filter(F.col("status") == "oversized")
                 .select(
                     "EVENT_ID",
                     "ADC_UPDT",
                     "compressed_size",
                     # array_min returns the struct with the lowest BLOB_SEQ_NUM (the
                     # first struct field). This does not depend on how collect_list
                     # happened to order chunks_data.
                     F.array_min("chunks_data").alias("first_chunk")
                 ))

    # insertInto is positional: this column order must match TARGET_TABLE's schema
    oversized_output = oversized.select(
        "EVENT_ID",
        F.col("first_chunk.VALID_UNTIL_DT_TM").alias("VALID_UNTIL_DT_TM"),
        F.col("first_chunk.VALID_FROM_DT_TM").alias("VALID_FROM_DT_TM"),
        F.col("first_chunk.UPDT_DT_TM").alias("UPDT_DT_TM"),
        F.col("first_chunk.UPDT_ID").cast("long").alias("UPDT_ID"),
        F.col("first_chunk.UPDT_TASK").cast("long").alias("UPDT_TASK"),
        F.col("first_chunk.UPDT_CNT").cast("long").alias("UPDT_CNT"),
        F.col("first_chunk.UPDT_APPLCTX").cast("long").alias("UPDT_APPLCTX"),
        F.col("first_chunk.LAST_UTC_TS").alias("LAST_UTC_TS"),
        "ADC_UPDT",
        F.lit(None).cast("binary").alias("BLOB_BINARY"),
        F.lit(None).cast("string").alias("CONTENT_TYPE"),
        F.lit(None).cast("string").alias("ENCODING"),
        F.lit(None).cast("string").alias("BLOB_TEXT"),
        F.col("compressed_size").alias("BINARY_SIZE"),
        F.lit(None).cast("long").alias("TEXT_LENGTH"),
        # Explicit cast: do not rely on implicit long->string coercion (stricter under ANSI mode)
        F.concat(F.lit("Compressed Too Large: "), F.col("compressed_size").cast("string"), F.lit(" bytes")).alias("STATUS"),
        F.lit(None).cast("string").alias("anon_text")
    )

    oversized_output.write.mode("append").insertInto(TARGET_TABLE)
    print(f"  Wrote {oversized_count} oversized events to target")

# Step 6: Create/update metadata table with schema evolution
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Updating metadata table...")

METADATA_TABLE = f"{STAGING_DB}.pipeline_metadata"

# Columns added after the table's first version. Tables created earlier are
# migrated in place by appending whichever of these is missing, in this order.
EVOLVED_METADATA_COLUMNS = [
    ("trust_filter", "STRING"),
    ("cdf_enabled", "BOOLEAN"),
    ("last_processed_ts", "TIMESTAMP"),
]

# Check if table exists and get its columns.
# `except Exception` rather than a bare `except:` so that KeyboardInterrupt and
# SystemExit are not mistaken for "table missing".
try:
    existing_columns = {f.name for f in spark.table(METADATA_TABLE).schema.fields}
    table_exists = True
    print(f"  Existing metadata table has columns: {existing_columns}")
except Exception:
    existing_columns = set()
    table_exists = False
    print("  Creating new metadata table")

if not table_exists:
    # Create the table with the full current schema
    spark.sql(f"""
        CREATE TABLE {METADATA_TABLE} (
            run_id STRING,
            worklist_table STRING,
            total_events INT,
            pending_events INT,
            oversized_events INT,
            created_ts TIMESTAMP,
            completed_ts TIMESTAMP,
            merged_ts TIMESTAMP,
            status STRING,
            batch_tables STRING,
            processed_events INT,
            trust_filter STRING,
            cdf_enabled BOOLEAN,
            last_processed_ts TIMESTAMP
        ) USING DELTA
    """)
    existing_columns = {'run_id', 'worklist_table', 'total_events', 'pending_events',
                        'oversized_events', 'created_ts', 'completed_ts', 'merged_ts',
                        'status', 'batch_tables', 'processed_events', 'trust_filter',
                        'cdf_enabled', 'last_processed_ts'}
else:
    # Add missing columns to an existing table
    for col_name, col_type in EVOLVED_METADATA_COLUMNS:
        if col_name not in existing_columns:
            spark.sql(f"ALTER TABLE {METADATA_TABLE} ADD COLUMN {col_name} {col_type}")
            print(f"  Added column: {col_name}")

# Insert metadata (positional VALUES; order matches the CREATE TABLE column list)
spark.sql(f"""
    INSERT INTO {METADATA_TABLE}
    VALUES (
        '{RUN_ID}',
        '{WORKLIST_TABLE}',
        {new_event_count},
        {stats_dict.get('pending', 0)},
        {stats_dict.get('oversized', 0)},
        current_timestamp(),
        NULL,
        NULL,
        'worklist_created',
        NULL,
        NULL,
        'Barts',
        {USE_CDF},
        timestamp('{last_processed_ts}')
    )
""")

print(f"\nPipeline metadata saved to: {METADATA_TABLE}")
print(f"Worklist table: {WORKLIST_TABLE}")
print(f"Run ID for next notebooks: {RUN_ID}")
print("="*80)