In [0]:
# Notebook 3: Merge Batch Tables and Cleanup
# Works with the shard-based processing from the optimized pipeline
# Ensures ALL events end up in target table

from pyspark.sql import functions as F
from datetime import datetime
import time
from functools import reduce

# Configuration
TARGET_TABLE = "4_prod.bronze.mill_blob_text"       # final destination (Steps 3, 5, 6)
METRICS_TABLE = "4_prod.logs.mill_blob_metrics"     # optional metrics sink (Step 4)
STAGING_DB = "4_prod.tmp"                           # schema searched for shard batch tables
METADATA_TABLE = f"{STAGING_DB}.pipeline_metadata"  # run registry (presumably written by the combined processor notebook)

# Select the newest run whose processing stage finished but has not been merged
# yet; the UPDATE at the end of this notebook moves it to status 'complete'.
latest_run = spark.sql(f"""
    SELECT run_id, worklist_table, batch_tables, processed_events
    FROM {METADATA_TABLE}
    WHERE status = 'processing_complete'
    ORDER BY created_ts DESC
    LIMIT 1
""").collect()

if not latest_run:
    # RuntimeError rather than a bare Exception; it subclasses Exception, so any
    # existing `except Exception` handler still catches it.
    raise RuntimeError("No completed processing found! Run the combined processor notebook first.")

run_row = latest_run[0]
RUN_ID = run_row['run_id']
WORKLIST_TABLE = run_row['worklist_table']
# NULL metadata values are normalised to an empty string / zero.
batch_tables_str = run_row['batch_tables'] or ""  # comma-separated table names
total_events = run_row['processed_events'] or 0

print("="*80)
print(f"MERGE AND CLEANUP - RUN {RUN_ID}")
print("="*80)
print(f"Worklist: {WORKLIST_TABLE}")

# Resolve the shard batch tables for this run: prefer the comma-separated list
# stored in metadata, otherwise fall back to a name-pattern search of STAGING_DB.
if batch_tables_str:
    batch_tables = [t.strip() for t in batch_tables_str.split(",") if t.strip()]
else:
    # If not in metadata, try to find them by pattern
    print("Batch tables not in metadata, searching by pattern...")
    all_tables = spark.sql(f"SHOW TABLES IN {STAGING_DB}").collect()
    # Metastore table names are case-insensitive and may be reported in lower
    # case, so match case-insensitively to avoid missing shards when RUN_ID
    # contains upper-case characters.
    pattern = f"batch_{RUN_ID}_shard_".lower()
    batch_tables = [f"{STAGING_DB}.{row.tableName}"
                    for row in all_tables
                    if row.tableName.lower().startswith(pattern)]

# De-duplicate while preserving order: a table listed twice would otherwise be
# merged into the target twice (Step 5), duplicating every row it holds.
batch_tables = list(dict.fromkeys(batch_tables))

if not batch_tables:
    print("WARNING: No batch tables found!")
else:
    print(f"Found {len(batch_tables)} batch tables:")
    for table in batch_tables:
        try:
            count = spark.table(table).count()
            print(f"  {table}: {count:,} records")
        except Exception as e:
            # Diagnostics only: an unreadable table is reported here, not fatal.
            print(f"  {table}: ERROR - {e}")

# Step 1: Oversized events bypass the shard pipeline -- they were already
# written to the target by Notebook 1, so here they are only counted/reported.
stamp = datetime.now().strftime('%H:%M:%S')
print(f"\n[{stamp}] Checking for oversized events...")

oversized_rows = spark.sql(f"""
    SELECT COUNT(*) as cnt
    FROM {WORKLIST_TABLE}
    WHERE status = 'oversized'
""").collect()
oversized_count = oversized_rows[0]['cnt']

if oversized_count:
    print(f"  {oversized_count:,} oversized events were already written to target in Notebook 1")

# Step 2: Register a temp view holding every EVENT_ID found in the shard batch
# tables; Step 3 uses it to exclude those events from the worklist.
if not batch_tables:
    processed_count = 0
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] No events in batch tables")
else:
    # One DISTINCT EVENT_ID projection per shard, glued together with UNION ALL.
    processed_events_query = " UNION ALL ".join(
        f"SELECT DISTINCT EVENT_ID FROM {t}" for t in batch_tables
    )

    spark.sql(f"""
        CREATE OR REPLACE TEMPORARY VIEW processed_events AS
        {processed_events_query}
    """)

    distinct_row = spark.sql("SELECT COUNT(DISTINCT EVENT_ID) as cnt FROM processed_events").collect()[0]
    processed_count = distinct_row['cnt']
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Found {processed_count:,} events in batch tables")

# Step 3: Find any events that didn't get processed
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Checking for unprocessed events...")

# Worklist rows that are neither oversized (already written by Notebook 1) nor
# completed and -- when batch tables exist -- absent from every batch table.
# * LEFT ANTI JOIN is used instead of `NOT IN (subquery)`: a single NULL
#   EVENT_ID in the subquery makes NOT IN evaluate to NULL for every row, which
#   would silently report zero unprocessed events.
# * COALESCE keeps rows whose status is NULL; `NULL NOT IN (...)` is NULL
#   (treated as false by WHERE) and would drop them.
anti_join_clause = (
    "LEFT ANTI JOIN processed_events p ON w.EVENT_ID = p.EVENT_ID"
    if batch_tables else ""
)
unprocessed = spark.sql(f"""
    SELECT w.EVENT_ID, w.ADC_UPDT, w.chunks_data, w.status
    FROM {WORKLIST_TABLE} w
    {anti_join_clause}
    WHERE COALESCE(w.status, '') NOT IN ('oversized', 'completed')
""")

unprocessed_count = unprocessed.count()

if unprocessed_count > 0:
    print(f"  Found {unprocessed_count:,} unprocessed events - writing with failed status...")

    # Failure placeholder rows: record metadata is copied from the first element
    # of chunks_data (element_at is 1-based) and all content columns are NULL.
    # insertInto() matches columns by position, so this select must follow the
    # target table's column order.
    failed_records = (unprocessed
        .select(
            "EVENT_ID",
            F.element_at(F.col("chunks_data"), 1).alias("first_chunk")
        )
        .select(
            F.col("EVENT_ID").cast("long"),
            F.col("first_chunk.VALID_UNTIL_DT_TM").alias("VALID_UNTIL_DT_TM"),
            F.col("first_chunk.VALID_FROM_DT_TM").alias("VALID_FROM_DT_TM"),
            F.col("first_chunk.UPDT_DT_TM").alias("UPDT_DT_TM"),
            F.col("first_chunk.UPDT_ID").cast("long").alias("UPDT_ID"),
            F.col("first_chunk.UPDT_TASK").cast("long").alias("UPDT_TASK"),
            F.col("first_chunk.UPDT_CNT").cast("long").alias("UPDT_CNT"),
            F.col("first_chunk.UPDT_APPLCTX").cast("long").alias("UPDT_APPLCTX"),
            F.col("first_chunk.LAST_UTC_TS").alias("LAST_UTC_TS"),
            F.col("first_chunk.ADC_UPDT").alias("ADC_UPDT"),
            F.lit(None).cast("binary").alias("BLOB_BINARY"),
            F.lit(None).cast("string").alias("CONTENT_TYPE"),
            F.lit(None).cast("string").alias("ENCODING"),
            F.lit(None).cast("string").alias("BLOB_TEXT"),
            F.lit(None).cast("long").alias("BINARY_SIZE"),
            F.lit(None).cast("long").alias("TEXT_LENGTH"),
            F.lit("Pipeline failed - not processed").alias("STATUS"),
            F.lit(None).cast("string").alias("anon_text")
        ))

    # Write failed records directly to target
    failed_records.write.mode("append").insertInto(TARGET_TABLE)
    print(f"  Wrote {unprocessed_count:,} failed records to target")
else:
    print("  All events were processed successfully")

# Step 4: Extract and save metrics (if batch tables exist and metrics table exists)
def _batch_metrics(table_name):
    """Return the metrics rows of one batch table, or None if there are none.

    Rows with a NULL ``metrics`` value are dropped and each remaining row is
    tagged with the current RUN_ID and a processing timestamp. A table without
    a ``metrics`` column yields None; a table that cannot be read is reported
    with a warning and also yields None.
    """
    try:
        source = spark.table(table_name)
        if "metrics" not in source.columns:
            return None
        return (source
                .filter(F.col("metrics").isNotNull())
                .select(
                    "EVENT_ID",
                    "ADC_UPDT",
                    "STATUS",
                    F.col("metrics"),
                    F.lit(RUN_ID).alias("RUN_ID"),
                    F.current_timestamp().alias("process_ts"),
                ))
    except Exception as e:
        print(f"  Warning: Could not extract metrics from {table_name}: {e}")
        return None


if batch_tables:
    try:
        # DESCRIBE raises when the metrics table does not exist, which jumps to
        # the handler below and skips metrics collection entirely.
        spark.sql(f"DESCRIBE TABLE {METRICS_TABLE}")

        print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Extracting metrics from {len(batch_tables)} batch tables...")

        metrics_dfs = [frame for frame in map(_batch_metrics, batch_tables)
                       if frame is not None]

        if metrics_dfs:
            # Union by position, then project into the column order that
            # insertInto() expects (presumably the metrics table's schema order).
            combined_metrics = reduce(lambda left, right: left.union(right), metrics_dfs).select(
                "EVENT_ID", "RUN_ID", "ADC_UPDT", "STATUS", "metrics", "process_ts"
            )
            combined_metrics.write.mode("append").insertInto(METRICS_TABLE)
            metrics_count = combined_metrics.count()
            print(f"  Saved {metrics_count:,} metrics records")
    except Exception as e:
        print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Metrics table not found or error: {e}")

# Step 5: Merge batch tables to target
if batch_tables:
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Merging {len(batch_tables)} batch tables to {TARGET_TABLE}...")
    merge_start = time.time()

    # Target column order. insertInto() resolves columns by POSITION, not by
    # name, so this list must mirror the target table's schema order.
    output_columns = [
        "EVENT_ID", "VALID_UNTIL_DT_TM", "VALID_FROM_DT_TM", "UPDT_DT_TM",
        "UPDT_ID", "UPDT_TASK", "UPDT_CNT", "UPDT_APPLCTX",
        "LAST_UTC_TS", "ADC_UPDT", "BLOB_BINARY", "CONTENT_TYPE",
        "ENCODING", "BLOB_TEXT", "BINARY_SIZE", "TEXT_LENGTH",
        "STATUS", "anon_text"
    ]

    # Type of the NULL placeholder used when a batch table lacks a column;
    # columns not listed here are filled with NULL strings.
    null_fill_types = {
        "EVENT_ID": "long",
        "UPDT_ID": "long", "UPDT_TASK": "long", "UPDT_CNT": "long",
        "UPDT_APPLCTX": "long", "BINARY_SIZE": "long", "TEXT_LENGTH": "long",
        "VALID_UNTIL_DT_TM": "timestamp", "VALID_FROM_DT_TM": "timestamp",
        "UPDT_DT_TM": "timestamp", "LAST_UTC_TS": "timestamp",
        "ADC_UPDT": "timestamp",
        "BLOB_BINARY": "binary",
    }

    valid_dfs = []
    for table in batch_tables:
        try:
            # Cast EVENT_ID to long to match the target. A table with no
            # EVENT_ID column fails here and is skipped by the handler below.
            df = spark.table(table).withColumn("EVENT_ID", F.col("EVENT_ID").cast("long"))
            # Check presence case-insensitively (matching Spark's default name
            # resolution) so e.g. a `status` column is kept instead of being
            # replaced by a NULL placeholder for `STATUS`.
            present = {name.lower() for name in df.columns}
            df = df.select(*[
                F.col(col).alias(col) if col.lower() in present
                else F.lit(None).cast(null_fill_types.get(col, "string")).alias(col)
                for col in output_columns
            ])
            valid_dfs.append(df)
            print(f"  Added {table} to merge")
        except Exception as e:
            print(f"  WARNING: Skipping {table} due to error: {e}")

    if valid_dfs:
        # Union by position is safe: every frame was projected to output_columns.
        # reduce() returns the sole element unchanged for a one-item list.
        final_df = reduce(lambda a, b: a.union(b), valid_dfs)

        # Repartition for write performance: ~5,000 rows per partition,
        # clamped to the range [1, 200].
        record_count = final_df.count()
        optimal_partitions = max(1, min(200, record_count // 5000))

        (final_df
         .repartition(optimal_partitions)
         .write
         .mode("append")
         .option("mergeSchema", "false")
         .option("optimizeWrite", "true")
         .insertInto(TARGET_TABLE))

        merge_time = time.time() - merge_start
        print(f"  Merged {record_count:,} records in {merge_time:.1f}s")
    else:
        print("  WARNING: No readable batch tables - nothing merged")
else:
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] No batch tables to merge")

# Step 6: Final validation - ensure ALL events from worklist are in target
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Final validation...")

# Distinct worklist events (COUNT(DISTINCT ...) ignores NULL EVENT_IDs)
worklist_total = spark.sql(f"""
    SELECT COUNT(DISTINCT EVENT_ID) as cnt
    FROM {WORKLIST_TABLE}
""").collect()[0]['cnt']

# Count how many of those events are now in target
target_total = spark.sql(f"""
    SELECT COUNT(DISTINCT t.EVENT_ID) as cnt
    FROM {TARGET_TABLE} t
    INNER JOIN {WORKLIST_TABLE} w ON t.EVENT_ID = w.EVENT_ID
""").collect()[0]['cnt']

print(f"  Worklist events: {worklist_total:,}")
print(f"  Events now in target: {target_total:,}")

missing_count = worklist_total - target_total

if missing_count > 0:
    print(f"\nWARNING: {missing_count} events still missing from target table!")

    # Force write any remaining missing events
    print("Force-writing missing events with error status...")

    # Worklist events absent from the target, exactly one row per EVENT_ID:
    # * NULL EVENT_IDs are skipped -- they never match the target and are not
    #   part of worklist_total, so writing them would only add junk rows.
    # * dropDuplicates(["EVENT_ID"]) keeps a single row per event, so an event
    #   with several worklist rows (e.g. differing chunks_data) is written once.
    missing_events = spark.sql(f"""
        SELECT w.EVENT_ID, w.chunks_data
        FROM {WORKLIST_TABLE} w
        LEFT ANTI JOIN {TARGET_TABLE} t ON w.EVENT_ID = t.EVENT_ID
        WHERE w.EVENT_ID IS NOT NULL
    """).dropDuplicates(["EVENT_ID"])

    # Placeholder rows: metadata from the first chunk, content columns NULL.
    # insertInto() matches by position, so keep the target's column order.
    forced_records = (missing_events
        .select(
            F.col("EVENT_ID").cast("long"),
            F.element_at(F.col("chunks_data"), 1).alias("first_chunk")
        )
        .select(
            "EVENT_ID",
            F.col("first_chunk.VALID_UNTIL_DT_TM").alias("VALID_UNTIL_DT_TM"),
            F.col("first_chunk.VALID_FROM_DT_TM").alias("VALID_FROM_DT_TM"),
            F.col("first_chunk.UPDT_DT_TM").alias("UPDT_DT_TM"),
            F.col("first_chunk.UPDT_ID").cast("long").alias("UPDT_ID"),
            F.col("first_chunk.UPDT_TASK").cast("long").alias("UPDT_TASK"),
            F.col("first_chunk.UPDT_CNT").cast("long").alias("UPDT_CNT"),
            F.col("first_chunk.UPDT_APPLCTX").cast("long").alias("UPDT_APPLCTX"),
            F.col("first_chunk.LAST_UTC_TS").alias("LAST_UTC_TS"),
            F.col("first_chunk.ADC_UPDT").alias("ADC_UPDT"),
            F.lit(None).cast("binary").alias("BLOB_BINARY"),
            F.lit(None).cast("string").alias("CONTENT_TYPE"),
            F.lit(None).cast("string").alias("ENCODING"),
            F.lit(None).cast("string").alias("BLOB_TEXT"),
            F.lit(None).cast("long").alias("BINARY_SIZE"),
            F.lit(None).cast("long").alias("TEXT_LENGTH"),
            F.lit("Force-written: processing incomplete").alias("STATUS"),
            F.lit(None).cast("string").alias("anon_text")
        ))

    # Count BEFORE writing: once appended, these events are in the target and
    # re-evaluating the anti join would return zero rows.
    forced_count = forced_records.count()
    forced_records.write.mode("append").insertInto(TARGET_TABLE)
    print(f"  Force-wrote {forced_count:,} missing events")

    # Re-validate
    target_total = spark.sql(f"""
        SELECT COUNT(DISTINCT t.EVENT_ID) as cnt
        FROM {TARGET_TABLE} t
        INNER JOIN {WORKLIST_TABLE} w ON t.EVENT_ID = w.EVENT_ID
    """).collect()[0]['cnt']

# Step 7: Cleanup staging tables (optional - can be disabled for debugging)
CLEANUP_ENABLED = True  # Set to False to keep staging tables for debugging

if not CLEANUP_ENABLED:
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Cleanup disabled - staging tables retained")
elif worklist_total != target_total:
    # Validation still reports unaccounted events: keep the worklist and batch
    # tables, since dropping them would destroy the only record of what needs
    # to be investigated.
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Validation incomplete "
          f"({worklist_total - target_total:,} events unaccounted for) - staging tables retained for debugging")
else:
    print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Cleaning up staging tables...")

    # Drop worklist table (best-effort: failures are reported, not raised)
    try:
        spark.sql(f"DROP TABLE IF EXISTS {WORKLIST_TABLE}")
        print(f"  Dropped {WORKLIST_TABLE}")
    except Exception as e:
        print(f"  Warning: Could not drop {WORKLIST_TABLE}: {e}")

    # Drop batch tables (best-effort, one at a time)
    for table in batch_tables:
        try:
            spark.sql(f"DROP TABLE IF EXISTS {table}")
            print(f"  Dropped {table}")
        except Exception as e:
            print(f"  Warning: Could not drop {table}: {e}")

# Summary: echo the run id and the final worklist-vs-target reconciliation.
rule = "=" * 80
print("\n" + rule)
print("PIPELINE COMPLETE")
print(rule)
for line in (f"Run ID: {RUN_ID}",
             f"Total worklist events: {worklist_total:,}",
             f"Successfully added to target: {target_total:,}"):
    print(line)

shortfall = worklist_total - target_total
print("✓ ALL events accounted for in target table" if shortfall == 0
      else f"⚠ WARNING: {shortfall} events may be missing")

print(rule)

# Update metadata to mark run as complete, so the run-selection query at the
# top of this notebook (status = 'processing_complete') no longer picks it up.
# RUN_ID is interpolated directly; it comes from METADATA_TABLE, not user input.
spark.sql(f"""
    UPDATE {METADATA_TABLE}
    SET status = 'complete',
        merged_ts = current_timestamp()
    WHERE run_id = '{RUN_ID}'
""")

print("\nNext steps:")
print("1. Verify data quality in target table")
print("2. Run OPTIMIZE on target table if needed:")
# Delta's `OPTIMIZE ... WHERE` only accepts partition-column predicates (no
# subqueries), and the worklist table may already have been dropped by the
# cleanup step, so suggest a plain table-level OPTIMIZE.
print(f"   OPTIMIZE {TARGET_TABLE}")
print("3. Monitor for any failed records (STATUS != 'Decoded')")