# In [0]:
# Combined Blob Processing Notebook
# Acts as orchestrator when run without parameters
# Acts as shard processor when run with SHARD_ID parameter

from pyspark.sql import functions as F, types as T, Row
from pyspark.sql.window import Window
import time
import json
from datetime import datetime
import uuid
import io, re, random, string, os, traceback, tempfile, shutil, subprocess
import signal
from contextlib import contextmanager

# ==================== PARAMETERS ====================
# Widgets make the same notebook usable in both modes: leave SHARD_ID blank to
# orchestrate, or set it to process a single shard.
dbutils.widgets.text("RUN_ID", "")
dbutils.widgets.text("SHARDS", "8")  # Total number of parallel shards
dbutils.widgets.text("SHARD_ID", "")  # Empty means run as orchestrator

RUN_ID = dbutils.widgets.get("RUN_ID").strip()
SHARDS = int(dbutils.widgets.get("SHARDS"))
SHARD_ID_STR = dbutils.widgets.get("SHARD_ID").strip()

# Determine execution mode
IS_ORCHESTRATOR = (SHARD_ID_STR == "")
SHARD_ID = None if IS_ORCHESTRATOR else int(SHARD_ID_STR)

# Validate parameters up front.
# - RUN_ID is interpolated into SQL string literals and batch table names, so
#   only identifier-safe characters are accepted.
# - Shards select rows with pmod(xxhash64(EVENT_ID), SHARDS) = SHARD_ID, so an
#   out-of-range SHARD_ID would silently match no events.
if RUN_ID and not re.fullmatch(r"[A-Za-z0-9_-]+", RUN_ID):
    raise ValueError(
        f"Invalid RUN_ID {RUN_ID!r}: only letters, digits, '_' and '-' are allowed"
    )
if SHARDS < 1:
    raise ValueError(f"SHARDS must be >= 1, got {SHARDS}")
if SHARD_ID is not None and not 0 <= SHARD_ID < SHARDS:
    raise ValueError(f"SHARD_ID must be in [0, {SHARDS}), got {SHARD_ID}")

# Configuration
STAGING_DB = "4_prod.tmp"
TARGET_TABLE = "4_prod.bronze.mill_blob_text"
METADATA_TABLE = f"{STAGING_DB}.pipeline_metadata"
# Size cap (bytes) applied to both the combined compressed blob and its
# decompressed payload.
MAX_BLOB_SIZE = 16 * 1024 * 1024
LZW_TIMEOUT_SECONDS = 30
OCF_MARKER = b'ocf_blob\0'

# OCR Configuration (not referenced by the extractors defined in this notebook)
ENABLE_OCR = True
OCR_MAX_PAGES = 10
OCR_MAX_PDF_SIZE_MB = 50
OCR_LANG = "eng"

# ==================== SPARK CONFIGURATION ====================
# Tuned for a Unity Catalog Shared cluster, where some keys are restricted.
# Each key is applied independently so that one rejected setting does not
# prevent the remaining ones from being applied.
_SPARK_CONF = {
    "spark.sql.adaptive.enabled": "true",
    # Disable AQE's automatic coalescing of shuffle partitions.
    "spark.sql.adaptive.coalescePartitions.enabled": "false",
    "spark.sql.shuffle.partitions": "320",
    "spark.databricks.delta.optimizeWrite.enabled": "true",
    "spark.databricks.delta.autoCompact.enabled": "true",
}
for _conf_key, _conf_value in _SPARK_CONF.items():
    try:
        spark.conf.set(_conf_key, _conf_value)
    except Exception as e:
        # Keep the cluster default for this key and continue with the rest.
        print(f"Note: Spark config {_conf_key} could not be set: {e}")

# ==================== MODE DETECTION ====================
# Print a banner describing which role this run plays.
_banner = "="*80
print(_banner)
if IS_ORCHESTRATOR:
    print("RUNNING AS ORCHESTRATOR")
    print(f"Will launch {SHARDS} parallel shards")
else:
    print("RUNNING AS SHARD PROCESSOR")
    print(f"Shard {SHARD_ID}/{SHARDS}")
print(_banner)

# ==================== ORCHESTRATOR MODE ====================
if IS_ORCHESTRATOR:
    # Resolve the run to process: the newest run whose worklist has been
    # created when RUN_ID is blank, otherwise the run named by RUN_ID.
    if not RUN_ID:
        latest_run = spark.sql(f"""
            SELECT run_id, worklist_table 
            FROM {METADATA_TABLE}
            WHERE status = 'worklist_created'
            ORDER BY created_ts DESC
            LIMIT 1
        """).collect()

        if not latest_run:
            raise Exception("No worklist found! Run Notebook 1 first.")

        RUN_ID = latest_run[0]['run_id']
        WORKLIST_TABLE = latest_run[0]['worklist_table']
    else:
        run_info = spark.sql(f"""
            SELECT worklist_table 
            FROM {METADATA_TABLE}
            WHERE run_id = '{RUN_ID}'
        """).collect()

        if not run_info:
            raise Exception(f"Run {RUN_ID} not found!")

        WORKLIST_TABLE = run_info[0]['worklist_table']

    print(f"Run ID: {RUN_ID}")
    print(f"Worklist: {WORKLIST_TABLE}")

    # Check worklist status
    print("\nChecking worklist status...")
    status_df = spark.sql(f"""
        SELECT 
            status,
            COUNT(*) as count,
            SUM(compressed_size) as total_bytes
        FROM {WORKLIST_TABLE}
        GROUP BY status
    """)

    status_df.show()

    # Collect the small per-status summary once instead of re-querying it.
    status_counts = {row['status']: row['count'] for row in status_df.collect()}
    total_pending = status_counts.get('pending') or 0
    if total_pending == 0:
        print("No pending events to process!")
        dbutils.notebook.exit("No work")

    print(f"\nTotal pending events: {total_pending:,}")

    # Launch shards
    print(f"\n{'='*80}")
    print(f"LAUNCHING {SHARDS} PARALLEL SHARDS")
    print(f"{'='*80}")

    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Each shard re-runs this same notebook with SHARD_ID set. Resolve the
    # notebook path once here rather than inside every worker thread.
    notebook_path = (
        dbutils.notebook.entry_point.getDbutils().notebook()
        .getContext().notebookPath().get()
    )

    def run_shard(shard_id):
        """Run one shard notebook; return (shard_id, "SUCCESS"|"FAILED", message)."""
        print(f"  Launching shard {shard_id}...")
        try:
            result = dbutils.notebook.run(
                notebook_path,
                timeout_seconds=7200,
                arguments={
                    "RUN_ID": str(RUN_ID),
                    "SHARDS": str(SHARDS),
                    "SHARD_ID": str(shard_id),  # makes the child run as a shard processor
                },
            )
            return (shard_id, "SUCCESS", result)
        except Exception as e:
            return (shard_id, "FAILED", str(e))

    # Launch all shards in parallel
    start_time = time.time()
    results = []

    with ThreadPoolExecutor(max_workers=SHARDS) as executor:
        futures = {executor.submit(run_shard, i): i for i in range(SHARDS)}

        for future in as_completed(futures):
            shard_id = futures[future]
            try:
                result = future.result()
                results.append(result)
                print(f"  Shard {result[0]} completed: {result[1]}")
            except Exception as e:
                results.append((shard_id, "FAILED", str(e)))
                print(f"  Shard {shard_id} failed: {str(e)}")

    elapsed = time.time() - start_time
    results.sort(key=lambda r: r[0])  # report in shard order, not completion order

    # Summary
    print("\n" + "-"*80)
    print("RESULTS SUMMARY:")
    for shard_id, status, message in results:
        print(f"  Shard {shard_id}: {status} - {message}")

    failed_shards = [sid for sid, status, _ in results if status != "SUCCESS"]
    success_count = len(results) - len(failed_shards)
    print(f"\nCompleted: {success_count}/{SHARDS} shards successful")
    print(f"Total time: {elapsed:.1f}s")

    # Verify results
    print("\n" + "="*80)
    print("VERIFYING RESULTS")
    print("="*80)

    final_status = spark.sql(f"""
        SELECT 
            status,
            COUNT(*) as count
        FROM {WORKLIST_TABLE}
        GROUP BY status
    """)

    print("\nFinal worklist status:")
    final_status.show()

    # Collect batch table names and count events
    batch_tables = []
    total_processed = 0
    for shard_id in range(SHARDS):
        batch_table = f"{STAGING_DB}.batch_{RUN_ID}_shard_{shard_id:04d}"
        if spark.catalog.tableExists(batch_table):
            batch_tables.append(batch_table)

    # Check batch tables
    if batch_tables:
        print(f"\nBatch tables created: {len(batch_tables)}")
        for table in batch_tables:
            count = spark.table(table).count()
            print(f"  {table}: {count:,} events")
            total_processed += count
        print(f"\nTotal events processed: {total_processed:,}")
    else:
        print("\n⚠ Warning: No batch tables found")

    if failed_shards:
        # Do not mark the run 'processing_complete' while some shards failed:
        # their events are still 'pending'. Leaving the metadata row unchanged
        # lets a rerun with the same RUN_ID retry exactly that remaining work,
        # and raising makes the failure visible to the calling job.
        print(f"\n⚠ Failed shards: {failed_shards} - metadata not updated")
        raise RuntimeError(
            f"{len(failed_shards)}/{SHARDS} shards failed ({failed_shards}); "
            f"metadata for run {RUN_ID} left unchanged so it can be retried"
        )

    # Update metadata with batch tables and final count - single update to avoid conflicts
    completed_count = final_status.filter(F.col("status") == "completed").select("count").collect()
    if completed_count and batch_tables:
        completed_events = completed_count[0]['count']
        batch_tables_str = ",".join(batch_tables)

        spark.sql(f"""
            UPDATE {METADATA_TABLE}
            SET 
                batch_tables = '{batch_tables_str}',
                processed_events = {total_processed},
                status = 'processing_complete',
                completed_ts = current_timestamp()
            WHERE run_id = '{RUN_ID}'
        """)

        print(f"\n✓ Processing complete!")
        print(f"✓ Completed events in worklist: {completed_events:,}")
        print(f"✓ Total records in batch tables: {total_processed:,}")
        print(f"✓ Batch tables: {batch_tables_str}")

    print("\n" + "="*80)
    print("NEXT STEPS:")
    print("1. Validate batch tables if needed")
    print("2. Run merge process to combine batch tables into final bronze table")
    print("3. Clean up batch tables after successful merge")

    dbutils.notebook.exit("Orchestration complete")

# ==================== SHARD PROCESSOR MODE ====================
# All the processing logic for when running as a shard

# Library imports
import chardet
from charset_normalizer import detect as cn_detect
from ocflzw_decompress.lzw import LzwDecompress
from striprtf.striprtf import rtf_to_text
from bs4 import BeautifulSoup
import docx2txt
from openpyxl import load_workbook
import xlrd
import pdfplumber
from pdfminer.high_level import extract_text as pdfminer_extract_text

# Optional libraries
try:
    import magic
except Exception:
    magic = None

try:
    import olefile
    HAVE_OLE = True
except Exception:
    HAVE_OLE = False

try:
    import extract_msg
    HAVE_EXTRACT_MSG = True
except Exception:
    HAVE_EXTRACT_MSG = False

try:  
    import fitz
    HAVE_FITZ = True  
except Exception:  
    HAVE_FITZ = False  

try:  
    import pypdf
    HAVE_PYPDF = True  
except Exception:  
    HAVE_PYPDF = False  

try:  
    import ocrmypdf
    HAVE_OCRMYPDF = True  
except Exception:  
    HAVE_OCRMYPDF = False  

try:  
    import pytesseract
    from PIL import Image
    HAVE_TESS = True  
except Exception:  
    HAVE_TESS = False

# Look up the worklist table registered for this run. In shard mode RUN_ID is
# mandatory (the orchestrator always passes it in the child's arguments).
if not RUN_ID:
    raise Exception("RUN_ID required for shard processing")

run_info = spark.sql(f"""
    SELECT worklist_table 
    FROM {METADATA_TABLE}
    WHERE run_id = '{RUN_ID}'
""").collect()

if not run_info:
    raise Exception(f"Run {RUN_ID} not found!")

# First matching metadata row wins.
WORKLIST_TABLE = run_info[0].worklist_table

print(f"Run ID: {RUN_ID}")
print(f"Worklist: {WORKLIST_TABLE}")

# ==================== HELPER FUNCTIONS (OPTIMIZED) ====================

class TimeoutException(Exception):
    """Signals that a time-limited operation exceeded its budget.

    decompress_blob_simple catches this type to record a timeout in its
    metrics and stop trying further strategies.
    """

def safe_numeric(value, default=None):
    """Coerce ``value`` to ``int``, returning ``default`` when that is impossible.

    ``None`` and ``''`` map to ``default``. Python ints/floats are truncated
    with ``int()``. Other values (str, Decimal, ...) are parsed exactly via
    ``Decimal`` so integers beyond 2**53 keep every digit. Fractional values
    are truncated toward zero. Values that are non-finite, unparsable, or whose
    magnitude is beyond the 64-bit range use the previous ``float`` route,
    so they behave as before.

    Args:
        value: Anything convertible to a number.
        default: Returned for missing or unconvertible values.

    Returns:
        The integer value, or ``default``.
    """
    if value is None or value == '':
        return default
    try:
        if isinstance(value, (int, float)):
            return int(value)
        from decimal import Decimal
        text = str(value)
        try:
            exact = Decimal(text)
        except Exception:
            exact = None
        # adjusted() is the base-10 exponent of the leading digit; < 19 keeps
        # the exact path within (and just beyond) the signed 64-bit range and
        # avoids materializing enormous ints for inputs like "1e100000000".
        if exact is not None and exact.is_finite() and exact.adjusted() < 19:
            return int(exact)
        return int(float(text))
    except Exception:
        return default

def combine_blob_chunks(chunks):
    """Concatenate chunk payloads in the given order, skipping empty or None chunks.

    Returns ``b""`` when ``chunks`` is None or holds no data.
    """
    non_empty = filter(None, chunks or ())
    return b"".join(non_empty)

def remove_ocf_wrapper_aggressive(data: bytes):
    """Strip every OCF marker (trailing and embedded) from ``data``.

    Falsy input is returned as-is. If anything raises, the (possibly already
    tail-trimmed) data is returned instead of propagating the error.
    """
    try:
        if not data:
            return data
        if data.endswith(OCF_MARKER):
            data = data[:len(data) - len(OCF_MARKER)]
        return b''.join(data.split(OCF_MARKER)) if OCF_MARKER in data else data
    except Exception:
        return data

def remove_ocf_wrapper_conservative(data: bytes):
    """Strip only a single trailing OCF marker from ``data``; embedded markers are kept.

    Falsy input, or input without the trailer, is returned unchanged; errors
    also fall back to returning the input.
    """
    try:
        has_trailer = bool(data) and data.endswith(OCF_MARKER)
        return data[:-len(OCF_MARKER)] if has_trailer else data
    except Exception:
        return data

def decompress_lzw_with_timeout(data: bytes, timeout_seconds=LZW_TIMEOUT_SECONDS):
    """Decompress OCF LZW data, aborting after ``timeout_seconds`` when possible.

    Payloads under 100,000 bytes and non-positive timeouts are decompressed
    directly. Larger payloads are guarded with ``SIGALRM``. That guard needs
    ``signal.signal``, which only works on the main thread, so on any other
    thread decompression runs unguarded.

    Args:
        data: Compressed bytes (OCF-marker cleanup is the caller's job).
        timeout_seconds: Wall-clock budget in seconds for the guarded path.

    Returns:
        The decompressed payload as ``bytes``.

    Raises:
        TimeoutException: The guarded decompression exceeded its budget.
        Exception: Whatever ``LzwDecompress`` raises on malformed input.
    """
    if timeout_seconds <= 0 or len(data) < 100000:
        return bytes(LzwDecompress().decompress(data))

    import threading
    import signal

    if threading.current_thread() is not threading.main_thread():
        return bytes(LzwDecompress().decompress(data))

    def _on_alarm(signum, frame):
        # TimeoutException is the type decompress_blob_simple catches to
        # record a timeout and stop retrying other strategies.
        raise TimeoutException(f"LZW timeout after {timeout_seconds}s")

    old_handler = signal.signal(signal.SIGALRM, _on_alarm)
    # alarm() takes whole seconds; round sub-second budgets up to 1s.
    signal.alarm(max(1, int(timeout_seconds)))
    try:
        return bytes(LzwDecompress().decompress(data))
    finally:
        # Always cancel the alarm: if decompression raised, a still-armed
        # alarm would fire later inside unrelated code.
        signal.alarm(0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; it cannot be reinstalled in that case.
        if old_handler is not None:
            signal.signal(signal.SIGALRM, old_handler)

def decompress_blob_simple(raw: bytes, compression_cd, chunk_count: int, metrics: dict):
    """Decode a combined blob according to its compression code.

    Args:
        raw: Concatenated chunk bytes.
        compression_cd: 727 = uncompressed, 728 = LZW; anything else is
            rejected. Converted with ``int()``; unparsable codes count as unknown.
        chunk_count: Number of non-empty chunks; many chunks (or many OCF
            markers) limit which cleanup strategies are tried.
        metrics: Mutated in place: ``ocf_marker_count``,
            ``decompression_strategy`` and, on timeout, ``timeout_hit`` /
            ``timeout_stage``.

    Returns:
        ``(payload, None)`` on success, ``(None, error_message)`` on failure.
    """
    if not raw:
        return None, "Empty content"

    ocf_count = raw.count(OCF_MARKER)
    metrics["ocf_marker_count"] = ocf_count

    try:
        cd = int(compression_cd) if compression_cd is not None else None
    except Exception:
        cd = None

    if cd == 727:  # Uncompressed
        metrics["decompression_strategy"] = "uncompressed"
        return (remove_ocf_wrapper_aggressive(raw) if ocf_count > 0 else raw), None

    if cd != 728:  # Not LZW
        return None, f"Unknown compression: {compression_cd}"

    aggressive = ("aggressive", remove_ocf_wrapper_aggressive)
    conservative = ("conservative", remove_ocf_wrapper_conservative)
    untouched = ("raw", lambda d: d)

    # Only try strategies that feed the decoder different bytes; LZW is
    # deterministic, so re-decoding identical input just repeats the failure.
    if ocf_count == 0:
        # No markers: all three strategies would return raw unchanged.
        strategies = [aggressive]
    elif chunk_count > 10 or ocf_count > 10:
        strategies = [aggressive, untouched]
    elif ocf_count == 1 and raw.endswith(OCF_MARKER):
        # A lone trailing marker: aggressive and conservative cleanups coincide.
        strategies = [aggressive, untouched]
    else:
        strategies = [aggressive, conservative, untouched]

    last_error = None
    for strategy_name, strategy_fn in strategies:
        try:
            cleaned = strategy_fn(raw)
            result = decompress_lzw_with_timeout(cleaned)
            metrics["decompression_strategy"] = f"lzw_{strategy_name}"
            return result, None
        except TimeoutException as e:
            # A timeout is not retried with other cleanups: they would be
            # similarly slow.
            metrics["timeout_hit"] = True
            metrics["timeout_stage"] = f"lzw_{strategy_name}"
            last_error = str(e)
            break
        except Exception as e:
            last_error = str(e)
            continue

    return None, f"LZW failed all attempts: {last_error}"

_PRINTABLE_CHARS = frozenset(string.printable)


def calculate_printable_ratio(text, sample_size=1000):
    """Return the fraction of characters in ``text`` that are ASCII-printable.

    Texts longer than ``sample_size`` are scored on ``sample_size`` evenly
    spaced positions spanning the whole text. The sample is deterministic, so
    a given input always scores the same. Random sampling would let
    guess_text pick different encodings for the same blob on different runs
    or task retries.

    Returns 0.0 for empty text or a non-positive ``sample_size``.
    """
    if not text:
        return 0.0
    n = len(text)
    if n <= sample_size:
        sample = text
    elif sample_size <= 0:
        return 0.0
    else:
        sample = [text[(i * n) // sample_size] for i in range(sample_size)]
    printable = sum(1 for c in sample if c in _PRINTABLE_CHARS)
    return printable / len(sample) if sample else 0.0

def guess_text(content: bytes):
    """Decode ``content`` with the encoding that yields the most printable text.

    Candidates are chardet's guess, charset_normalizer's guess, then
    utf-8, windows-1252, latin-1 and ascii (all decoded with
    ``errors='ignore'``). The first candidate to score above 0.95 ends the
    search early; invalid or failing candidates are skipped.

    Returns:
        ``(decoded_text, encoding, printable_ratio)``, or
        ``(None, None, 0.0)`` for empty content.
    """
    if not content:
        return None, None, 0.0

    detected = [
        (chardet.detect(content) or {}).get('encoding'),
        (cn_detect(content) or {}).get('encoding'),
    ]
    best = (None, None, 0.0)  # (decoded, encoding, ratio)

    for encoding in detected + ['utf-8', 'windows-1252', 'latin-1', 'ascii']:
        if not encoding:
            continue
        try:
            text = content.decode(encoding, errors='ignore')
            ratio = calculate_printable_ratio(text)
            if ratio > best[2]:
                best = (text, encoding, ratio)
            if best[2] > 0.95:
                break
        except Exception:
            continue

    return best

def clean_text(text):
    """Normalize extracted text; non-str values are returned unchanged.

    Steps: drop ``<% ... %>`` template tags (across lines), treat ``|`` as a
    line break, collapse every run of newlines into one, and strip the ends.
    (A separate "3+ newlines -> 2" pass was removed: the full collapse that
    follows always overrode it.)
    """
    if not isinstance(text, str):
        return text

    cleaned = re.sub(r'<%.*?%>', '', text, flags=re.DOTALL)
    cleaned = cleaned.replace('|', '\n')
    cleaned = re.sub(r'\n+', '\n', cleaned)
    return cleaned.strip()

def detect_mime(content: bytes):
    """Guess a MIME type for ``content`` from its leading bytes.

    Checks, in order:
    - PDF header.
    - ZIP header, classified as DOCX/XLSX/PPTX by the part names visible in
      the first 4 KiB, else plain ZIP.
    - OLE2 compound-file signature.
    - An RTF ``{\\`` prefix.

    Falls back to libmagic when the optional ``magic`` module is importable,
    and finally to ``application/octet-stream``.
    """
    if not content:
        return 'application/octet-stream'

    if content.startswith(b'%PDF-'):
        return 'application/pdf'

    if content.startswith(b'\x50\x4B\x03\x04'):
        head = content[:4096]
        if b'word/' in head:
            return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
        if b'xl/' in head:
            return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
        if b'ppt/' in head:
            return 'application/vnd.openxmlformats-officedocument.presentationml.presentation'
        return 'application/zip'

    if content.startswith(b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'):
        return 'application/x-ole-storage'

    if content.startswith(b'{\\'):
        return 'text/rtf'

    if magic:
        try:
            # The module-level from_buffer reuses python-magic's cached
            # detector instead of reloading the libmagic database per call.
            return magic.from_buffer(content, mime=True) or 'application/octet-stream'
        except Exception:
            pass

    return 'application/octet-stream'

# Root-level storage names used by Outlook .msg files (MS-OXMSG). Recipient
# and attachment storages carry a "_#XXXXXXXX" suffix, so they must be
# matched by prefix; "__nameid_version1.0" is the named-property mapping storage.
_MSG_STORAGE_PREFIXES = ('__recip_version1.0_', '__attach_version1.0_', '__nameid_version1.0')


def _ole_has_msg_storages(ole):
    """Return True if the OLE root holds storages characteristic of an .msg file."""
    # Exact-name checks kept from the original classification.
    if ole.exists('__recip_version1.0') or ole.exists('__attach_version1.0'):
        return True
    for entry in ole.listdir(streams=False, storages=True):
        if entry and entry[0].lower().startswith(_MSG_STORAGE_PREFIXES):
            return True
    return False


def classify_ole(data: bytes):
    """Map an OLE2 compound file to a concrete MIME type by its well-known streams.

    Returns msword / ms-excel / ms-powerpoint / ms-outlook when the matching
    streams or storages are present. Otherwise it returns the generic
    ``application/x-ole-storage``. The generic type is also returned when
    olefile is unavailable, the data is empty, or the file cannot be opened.
    """
    if not (HAVE_OLE and data):
        return 'application/x-ole-storage'

    try:
        with olefile.OleFileIO(io.BytesIO(data)) as ole:
            if ole.exists('WordDocument'):
                return 'application/msword'
            if ole.exists('Workbook') or ole.exists('Book'):
                return 'application/vnd.ms-excel'
            if ole.exists('PowerPoint Document'):
                return 'application/vnd.ms-powerpoint'
            if ole.exists('__properties_version1.0') and _ole_has_msg_storages(ole):
                return 'application/vnd.ms-outlook'
            return 'application/x-ole-storage'
    except Exception:
        return 'application/x-ole-storage'

def refine_mime_with_ole(content_type, data: bytes):
    """Resolve the generic OLE container type via classify_ole; pass any other type through."""
    if content_type != 'application/x-ole-storage':
        return content_type
    return classify_ole(data)

# Text extractors. Only PDF (pypdf / PyMuPDF), DOCX and Excel extractors are
# defined in this notebook; parse_blob_content falls back to encoding-guessed
# text decoding for every other content type.

def extract_pdf_with_pypdf(content: bytes):
    """Extract text from a PDF with pypdf, or return None.

    An empty password is tried on encrypted files. Returns the non-blank page
    texts joined by newlines, or None when pypdf is unavailable, nothing was
    extracted, or any error occurred.
    """
    if not HAVE_PYPDF:
        return None
    try:
        reader = pypdf.PdfReader(io.BytesIO(content), strict=False)
        if getattr(reader, "is_encrypted", False):
            try:
                reader.decrypt("")
            except Exception:
                pass
        page_texts = (page.extract_text() or "" for page in reader.pages)
        non_blank = [t for t in page_texts if t.strip()]
        return "\n".join(non_blank) if non_blank else None
    except Exception:
        return None

def extract_pdf_with_pymupdf(content: bytes):
    """Extract text from a PDF with PyMuPDF (fitz), or return None.

    An empty password is tried on protected files. Returns the non-blank page
    texts joined by newlines, or None when PyMuPDF is unavailable, nothing was
    extracted, or any error occurred. The document is always closed so
    MuPDF's native resources are released even in long-lived Python workers.
    """
    if not HAVE_FITZ:
        return None
    doc = None
    try:
        doc = fitz.open(stream=content, filetype="pdf")
        if doc.needs_pass:
            try:
                doc.authenticate("")
            except Exception:
                pass
        parts = []
        for page in doc:
            t = page.get_text("text") or ""
            if t.strip():
                parts.append(t)
        return "\n".join(parts) if parts else None
    except Exception:
        return None
    finally:
        if doc is not None:
            try:
                doc.close()
            except Exception:
                pass

def extract_text_from_docx(content):
    """Return docx2txt's text for a DOCX payload, or None if extraction raises."""
    try:
        text = docx2txt.process(io.BytesIO(content))
    except Exception:
        text = None
    return text

# OLE2 compound-file signature; legacy .xls workbooks start with it.
_XLS_OLE_SIGNATURE = b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'


def _xlsx_row_texts(content):
    """Return non-blank, space-joined row strings of an OOXML workbook (openpyxl)."""
    wb = load_workbook(io.BytesIO(content), read_only=True, data_only=True)
    try:
        parts = []
        # worksheets skips chartsheets, which have no iter_rows() and would
        # otherwise make the whole workbook fail.
        for ws in wb.worksheets:
            for row in ws.iter_rows(values_only=True):
                row_text = ' '.join(str(cell) for cell in row if cell is not None)
                if row_text.strip():
                    parts.append(row_text)
        return parts
    finally:
        # read_only workbooks keep their archive open until explicitly closed.
        wb.close()


def _xls_row_texts(content):
    """Return non-blank, space-joined row strings of a legacy .xls workbook (xlrd)."""
    book = xlrd.open_workbook(file_contents=bytes(content))
    try:
        parts = []
        for sheet in book.sheets():
            for row_idx in range(sheet.nrows):
                # xlrd reports empty cells as '' rather than None.
                values = sheet.row_values(row_idx)
                row_text = ' '.join(str(v) for v in values if v is not None and v != '')
                if row_text.strip():
                    parts.append(row_text)
        return parts
    finally:
        book.release_resources()


def extract_text_from_excel(content):
    """Extract cell text from an Excel workbook, one line per non-blank row.

    OLE2 content (legacy .xls, routed here as application/vnd.ms-excel) is
    read with xlrd; everything else is read as OOXML with openpyxl. Returns
    the newline-joined rows (possibly ''), or None if reading fails.
    """
    try:
        if bytes(content[:8]) == _XLS_OLE_SIGNATURE:
            parts = _xls_row_texts(content)
        else:
            parts = _xlsx_row_texts(content)
        return '\n'.join(parts)
    except Exception:
        return None

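# NOTE: rtf_to_text, BeautifulSoup, extract_msg and the OCR libraries/settings
# (ocrmypdf, pytesseract, ENABLE_OCR, OCR_*) are imported or defined above but
# no extractor in this notebook uses them yet.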

def parse_blob_content(content: bytes, provided_type=None):
    """Extract text from a decompressed blob.

    The MIME type is ``provided_type`` or sniffed with detect_mime, then
    refined for OLE containers. PDFs use pypdf, then PyMuPDF; a failure yields
    the marker text "[PDF - extraction failed]". DOCX and Excel types use
    their extractors and fall through to plain decoding when those yield
    nothing. Everything else is decoded with guess_text.

    Returns:
        ``(text, content_type, encoding)``; ``(None, None, None)`` for empty
        content.
    """
    if not content:
        return None, None, None

    mime = refine_mime_with_ole(provided_type or detect_mime(content), content)

    if mime == 'application/pdf':
        pdf_text = extract_pdf_with_pypdf(content) or extract_pdf_with_pymupdf(content)
        if not pdf_text:
            return "[PDF - extraction failed]", mime, None
        return clean_text(pdf_text), mime, 'utf-8'

    office_extractor = {
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': extract_text_from_docx,
        'application/vnd.ms-excel': extract_text_from_excel,
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': extract_text_from_excel,
    }.get(mime)
    if office_extractor is not None:
        office_text = office_extractor(content)
        if office_text:
            return clean_text(office_text), mime, 'utf-8'

    # Default text extraction
    decoded, enc, _ratio = guess_text(content)
    if not decoded:
        return "[Binary data, unable to decode]", mime, None
    return clean_text(decoded), mime, enc

# ==================== MAIN PROCESSING UDF ====================

# Output row layout of process_blob_udf; every column is nullable.
process_output_schema = T.StructType([
    T.StructField(field_name, field_type, True)
    for field_name, field_type in (
        ("EVENT_ID", T.LongType()),
        ("VALID_UNTIL_DT_TM", T.TimestampType()),
        ("VALID_FROM_DT_TM", T.TimestampType()),
        ("UPDT_DT_TM", T.TimestampType()),
        ("UPDT_ID", T.LongType()),
        ("UPDT_TASK", T.LongType()),
        ("UPDT_CNT", T.LongType()),
        ("UPDT_APPLCTX", T.LongType()),
        ("LAST_UTC_TS", T.TimestampType()),
        ("ADC_UPDT", T.TimestampType()),
        ("BLOB_BINARY", T.BinaryType()),
        ("CONTENT_TYPE", T.StringType()),
        ("ENCODING", T.StringType()),
        ("BLOB_TEXT", T.StringType()),
        ("BINARY_SIZE", T.LongType()),
        ("TEXT_LENGTH", T.LongType()),
        ("STATUS", T.StringType()),
        ("anon_text", T.StringType()),
        ("metrics", T.StringType()),
    )
])

# Upper bound on the BLOB_TEXT length (characters) stored per event.
_MAX_BLOB_TEXT_CHARS = 1000000

# parse_blob_content reports failures as a whole, single-line bracketed string
# such as "[PDF - extraction failed]". Only text of exactly that shape is
# treated as a status marker, so real text that merely begins with "[" (e.g.
# "[Addendum] ...") is still reported as "Decoded".
_STATUS_MARKER_RE = re.compile(r"\[([^\[\]\r\n]{1,200})\]")


def _status_from_text(blob_text):
    """Return the STATUS value for parsed text (marker contents, Decoded, or Failed)."""
    if not blob_text:
        return "Failed to decode"
    if isinstance(blob_text, str):
        marker = _STATUS_MARKER_RE.fullmatch(blob_text)
        if marker:
            return marker.group(1)
    return "Decoded"


def _blob_row(event_id, valid_until, valid_from, updt_dt, updt_id, updt_task,
              updt_cnt, updt_applctx, last_utc, adc_updt, metrics, *, status,
              content_type=None, encoding=None, blob_text=None,
              binary_size=None, text_length=None):
    """Build one output Row with the columns of process_output_schema.

    Pass-through IDs are normalized with safe_numeric; BLOB_BINARY and
    anon_text are always None. ``metrics`` is JSON-encoded at call time, so
    it must already hold its final values.
    """
    return Row(
        EVENT_ID=safe_numeric(event_id),
        VALID_UNTIL_DT_TM=valid_until,
        VALID_FROM_DT_TM=valid_from,
        UPDT_DT_TM=updt_dt,
        UPDT_ID=safe_numeric(updt_id),
        UPDT_TASK=safe_numeric(updt_task),
        UPDT_CNT=safe_numeric(updt_cnt),
        UPDT_APPLCTX=safe_numeric(updt_applctx),
        LAST_UTC_TS=last_utc,
        ADC_UPDT=adc_updt,
        BLOB_BINARY=None,
        CONTENT_TYPE=content_type,
        ENCODING=encoding,
        BLOB_TEXT=blob_text,
        BINARY_SIZE=binary_size,
        TEXT_LENGTH=text_length,
        STATUS=status,
        anon_text=None,
        metrics=json.dumps(metrics),
    )


@F.udf(returnType=process_output_schema)
def process_blob_udf(event_id, chunks_data, valid_until, valid_from, updt_dt, 
                     updt_id, updt_task, updt_cnt, updt_applctx, 
                     last_utc, adc_updt):
    """Combine, decompress and parse one event's blob chunks into an output Row.

    Size-limit breaches, decompression failures and unexpected exceptions are
    all reported through STATUS (and ``metrics["error"]`` for exceptions)
    rather than raised. ``chunks_data`` is expected to be a sequence of
    structs with BLOB_SEQ_NUM, BLOB_CONTENTS and COMPRESSION_CD fields.
    """
    import time

    metrics = {
        "combine_ms": 0,
        "decompress_ms": 0,
        "parse_ms": 0,
        "chunk_count": 0,
        "ocf_marker_count": 0,
        "decompression_strategy": None,
        "timeout_hit": False,
        "timeout_stage": None,
        "error": None
    }

    def row(**fields):
        # Every exit path shares the same pass-through metadata and metrics.
        return _blob_row(event_id, valid_until, valid_from, updt_dt, updt_id,
                         updt_task, updt_cnt, updt_applctx, last_utc,
                         adc_updt, metrics, **fields)

    try:
        t0 = time.perf_counter()

        # Reassemble in sequence order; a missing sequence number sorts first.
        sorted_chunks = sorted(chunks_data, key=lambda x: x['BLOB_SEQ_NUM'] or 0)
        chunks = [c['BLOB_CONTENTS'] for c in sorted_chunks if c['BLOB_CONTENTS']]
        metrics["chunk_count"] = len(chunks)

        # The compression code is taken from the lowest-sequence chunk.
        compression_cd = sorted_chunks[0]['COMPRESSION_CD'] if sorted_chunks else None

        combined = combine_blob_chunks(chunks)
        metrics["combine_ms"] = int((time.perf_counter() - t0) * 1000)

        if len(combined) > MAX_BLOB_SIZE:
            return row(binary_size=len(combined),
                       status=f"Compressed Too Large: {len(combined)} bytes")

        t0 = time.perf_counter()
        decompressed, dec_err = decompress_blob_simple(combined, compression_cd, len(chunks), metrics)
        metrics["decompress_ms"] = int((time.perf_counter() - t0) * 1000)

        if decompressed is None:
            return row(status=dec_err or "Decompression failed")

        if len(decompressed) > MAX_BLOB_SIZE:
            return row(binary_size=len(decompressed),
                       status=f"Decompressed too large: {len(decompressed)} bytes")

        t0 = time.perf_counter()
        blob_text, content_type, encoding = parse_blob_content(decompressed, None)
        metrics["parse_ms"] = int((time.perf_counter() - t0) * 1000)

        status = _status_from_text(blob_text)

        if blob_text and isinstance(blob_text, str):
            # Round-trip through UTF-8 to drop unencodable characters (e.g.
            # lone surrogates), then cap the stored length.
            blob_text = blob_text.encode('utf-8', errors='ignore').decode('utf-8')
            blob_text = blob_text[:_MAX_BLOB_TEXT_CHARS]

        return row(
            content_type=content_type,
            encoding=encoding,
            blob_text=blob_text,
            binary_size=len(decompressed) if decompressed else None,
            text_length=len(blob_text) if blob_text else None,
            status=status,
        )

    except Exception as e:
        metrics["error"] = str(e)[:500]
        return row(status=f"Error: {str(e)[:200]}")

# ==================== MAIN SHARD PROCESSING ====================

def _is_concurrent_write_conflict(exc):
    """Return True if ``exc`` looks like a Delta concurrent-transaction conflict.

    Matched on the exception's class name and message (e.g.
    ConcurrentAppendException, ConcurrentDeleteReadException,
    DELTA_CONCURRENT_* error classes) so no Delta-specific import is needed.
    """
    text = f"{type(exc).__name__}: {exc}"
    return "Concurrent" in text or "CONCURRENT" in text


def _run_sql_with_conflict_retry(statement, max_attempts=5, base_delay_s=2.0):
    """Execute ``statement`` with spark.sql, retrying Delta write conflicts.

    All shards UPDATE the same worklist table in parallel, so under Delta's
    optimistic concurrency one shard's commit can invalidate another's. The
    worklist UPDATE below only flips rows that are still 'pending', so
    re-running it after a conflict is safe. Non-conflict errors, and a
    conflict on the last attempt, are re-raised.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return spark.sql(statement)
        except Exception as exc:
            if attempt >= max_attempts or not _is_concurrent_write_conflict(exc):
                raise
            # Exponential backoff plus jitter so retrying shards spread out.
            delay = base_delay_s * (2 ** (attempt - 1)) + random.uniform(0, base_delay_s)
            print(f"[{datetime.now().strftime('%H:%M:%S')}] Write conflict "
                  f"(attempt {attempt}/{max_attempts}); retrying in {delay:.1f}s")
            time.sleep(delay)


total_start = time.time()

# 1) Select pending events for this shard deterministically
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Selecting events for shard {SHARD_ID}...")

events_df = spark.sql(f"""
    SELECT EVENT_ID, ADC_UPDT, chunks_data, compressed_size
    FROM {WORKLIST_TABLE}
    WHERE status = 'pending'
      AND pmod(xxhash64(EVENT_ID), {SHARDS}) = {SHARD_ID}
""")

# 2) Count and check if there's work
event_count = events_df.count()
print(f"[{datetime.now().strftime('%H:%M:%S')}] Shard {SHARD_ID}: Found {event_count:,} events to process")

if event_count == 0:
    print(f"Shard {SHARD_ID} has no work.")
    dbutils.notebook.exit(f"Shard {SHARD_ID}: No work")

# 3) Repartition for parallel UDF execution (between 160 and 320 partitions)
desired_partitions = max(160, min(320, (event_count // 200) or 1))
print(f"[{datetime.now().strftime('%H:%M:%S')}] Repartitioning to {desired_partitions} partitions...")
events_df = events_df.repartition(desired_partitions, F.col("EVENT_ID"))

# 4) Build first-chunk projection: row metadata is taken from the first
#    element of chunks_data (array order, not BLOB_SEQ_NUM order).
print(f"[{datetime.now().strftime('%H:%M:%S')}] Extracting metadata from first chunks...")
first = (events_df
    .select(
        "EVENT_ID", "ADC_UPDT", "chunks_data",
        F.element_at(F.col("chunks_data"), 1).alias("first_chunk")
    )
    .select(
        "EVENT_ID", "chunks_data",
        F.col("first_chunk.VALID_UNTIL_DT_TM").alias("VALID_UNTIL_DT_TM"),
        F.col("first_chunk.VALID_FROM_DT_TM").alias("VALID_FROM_DT_TM"),
        F.col("first_chunk.UPDT_DT_TM").alias("UPDT_DT_TM"),
        F.col("first_chunk.UPDT_ID").alias("UPDT_ID"),
        F.col("first_chunk.UPDT_TASK").alias("UPDT_TASK"),
        F.col("first_chunk.UPDT_CNT").alias("UPDT_CNT"),
        F.col("first_chunk.UPDT_APPLCTX").alias("UPDT_APPLCTX"),
        F.col("first_chunk.LAST_UTC_TS").alias("LAST_UTC_TS"),
        F.col("first_chunk.ADC_UPDT").alias("ADC_UPDT")
    )
)

# 5) Apply UDF
print(f"[{datetime.now().strftime('%H:%M:%S')}] Processing blobs with UDF...")
processed = first.select(
    process_blob_udf(
        F.col("EVENT_ID"),
        F.col("chunks_data"),
        F.col("VALID_UNTIL_DT_TM"),
        F.col("VALID_FROM_DT_TM"),
        F.col("UPDT_DT_TM"),
        F.col("UPDT_ID"),
        F.col("UPDT_TASK"),
        F.col("UPDT_CNT"),
        F.col("UPDT_APPLCTX"),
        F.col("LAST_UTC_TS"),
        F.col("ADC_UPDT")
    ).alias("r")
).select("r.*")

# 6) Write to batch table for this shard
batch_table = f"{STAGING_DB}.batch_{RUN_ID}_shard_{SHARD_ID:04d}"
print(f"[{datetime.now().strftime('%H:%M:%S')}] Writing results to batch table: {batch_table}")

# Write with optimized partitions for later merge (between 10 and 100 files)
write_partitions = max(10, min(100, event_count // 1000))
processed.repartition(write_partitions).write.mode("overwrite").saveAsTable(batch_table)

# 7) Mark this shard's rows as completed
print(f"[{datetime.now().strftime('%H:%M:%S')}] Updating worklist status...")
events_df.select("EVENT_ID").distinct().createOrReplaceTempView("processed_ids")

_run_sql_with_conflict_retry(f"""
    UPDATE {WORKLIST_TABLE} AS w
    SET status = 'completed', process_ts = current_timestamp()
    WHERE status = 'pending'
      AND pmod(xxhash64(w.EVENT_ID), {SHARDS}) = {SHARD_ID}
      AND EXISTS (SELECT 1 FROM processed_ids p WHERE p.EVENT_ID = w.EVENT_ID)
""")

# 8) Report metrics
elapsed = time.time() - total_start
print("\n" + "="*80)
print(f"SHARD {SHARD_ID} COMPLETE")
print("="*80)
print(f"Events processed: {event_count:,}")
print(f"Batch table: {batch_table}")
print(f"Total time: {elapsed:.1f}s")
print(f"Processing rate: {event_count / max(elapsed, 1e-9):.1f} events/sec")

# 9) Don't update metadata from shards - let orchestrator do it to avoid conflicts
# Each shard just reports its results
print(f"Shard {SHARD_ID} completed successfully")

# Return success with batch table info
dbutils.notebook.exit(f"Shard {SHARD_ID}: Processed {event_count} events -> {batch_table}")