In [0]:
from pyspark.sql.types import *
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType
from contextlib import contextmanager
import pandas as pd
from ocflzw_decompress.lzw import LzwDecompress
from striprtf.striprtf import rtf_to_text
from bs4 import BeautifulSoup
import re
import chardet
import traceback
from charset_normalizer import detect
import magic
import random
import string
import io
from datetime import datetime
import docx2txt
import xlrd
from openpyxl import load_workbook
import sys
import signal
import time
import pdfplumber
import tempfile
import subprocess
import os
import threading




# At the beginning of your script
def error_handler(exctype, value, tb):
    """Report an otherwise-unhandled exception on stderr.

    Follows the ``sys.excepthook`` signature: prints a header, the exception
    type and the exception value, then the traceback frames.
    """
    stream = sys.stderr
    for fields in (("Uncaught exception:",), ("Type:", exctype), ("Value:", value)):
        print(*fields, file=stream)
    traceback.print_tb(tb)


# Install the reporter as the hook for uncaught exceptions.
sys.excepthook = error_handler

In [0]:
def get_max_adc_updt(table_name):
    """Return the newest ``ADC_UPDT`` value stored in *table_name*.

    Falls back to 1980-01-01 when the table holds no rows (``MAX`` is NULL)
    or when the query fails. A failure is reported instead of being swallowed
    silently, because falling back to the default makes the caller treat
    every source row as new.

    NOTE(review): *table_name* is interpolated into the SQL text, so it must
    come from trusted configuration, never from user input.
    """
    default_date = datetime(1980, 1, 1)  # Python datetime object
    try:
        # The query already yields a single row holding the maximum, so no
        # second aggregation over the result is needed.
        row = spark.sql(f"SELECT MAX(ADC_UPDT) AS max_date FROM {table_name}").first()
    except Exception as exc:
        print(f"Could not read MAX(ADC_UPDT) from {table_name}: {exc}. Using {default_date}.")
        return default_date

    max_date = row["max_date"] if row is not None else None
    return max_date if max_date is not None else default_date

def table_exists(table_name):
    """Return True when *table_name* can be queried, False otherwise.

    An existing table with zero rows counts as existing; previously it was
    reported as missing because the probe query returned no row.

    NOTE(review): *table_name* is interpolated into the SQL text, so it must
    come from trusted configuration.
    """
    try:
        # Running the probe forces the table to be resolved; the row itself
        # (or its absence) is irrelevant.
        spark.sql(f"SELECT 1 FROM {table_name} LIMIT 1").first()
    except Exception:
        return False
    return True

In [0]:
# Define helper functions

def format_size(size_bytes):
    """Render a byte count as a two-decimal string with a binary unit.

    The value is divided by 1024 until it drops below 1024 or the units up to
    TB are exhausted, in which case it is reported in PB.
    """
    units = ('bytes', 'KB', 'MB', 'GB', 'TB')
    value = size_bytes
    index = 0
    while index < len(units):
        if value < 1024.0:
            return f"{value:.2f} {units[index]}"
        value = value / 1024.0
        index += 1
    return f"{value:.2f} PB"



def combine_blob_chunks(chunks):
    """Concatenate the given chunks, in order, into one immutable bytes object."""
    assembled = bytearray()
    add_piece = assembled.extend
    for piece in chunks:
        add_piece(piece)
    return bytes(assembled)

def decompress_blob(blob_contents, compression_cd):
    """Decompress a blob according to its compression code.

    Returns:
        bytes: the payload for code 727 (no compression) or 728 (LZW).
        str: a message for an unknown code or a decompression failure.
        None: when *blob_contents* is empty or not bytes/bytearray.
    """
    usable = bool(blob_contents) and isinstance(blob_contents, (bytes, bytearray))
    if not usable:
        return None

    try:
        if compression_cd == 728:  # LZW compression
            decoder = LzwDecompress()
            return bytes(decoder.decompress(blob_contents))
        if compression_cd == 727:  # No compression
            return bytes(blob_contents)
        return f"Unknown compression type: {compression_cd}"
    except Exception as exc:
        return f"Decompression error: {str(exc)}"


def _classify_zip_container(content):
    """Tell Word/Excel OOXML packages apart from generic ZIP archives."""
    import zipfile  # stdlib; imported here so this block is self-contained

    docx_type = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
    xlsx_type = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'

    # Fast path: the member name sometimes appears in the first local header.
    head = content[:200]
    if b'word/' in head:
        return docx_type
    if b'xl/' in head:
        return xlsx_type

    # Slow path: the first archive member need not live under word/ or xl/,
    # so consult the archive's member list before settling for plain zip.
    try:
        with zipfile.ZipFile(io.BytesIO(content)) as archive:
            names = archive.namelist()
    except Exception:
        # Truncated or otherwise unreadable archive: keep the generic answer.
        return 'application/zip'

    if any(name.startswith('word/') for name in names):
        return docx_type
    if any(name.startswith('xl/') for name in names):
        return xlsx_type
    return 'application/zip'


def enhanced_content_type_detection(content):
    """Detect the MIME type of *content* from file signatures, then libmagic.

    Signature checks run first and do not depend on libmagic, so a libmagic
    failure can no longer hide a recognisable PDF/ZIP/OLE/RTF header.

    Returns:
        The MIME type string, or None when *content* cannot be inspected or
        libmagic fails on content without a known signature.
    """
    try:
        if content.startswith(b'%PDF-'):
            return 'application/pdf'
        if content.startswith(b'\x50\x4B\x03\x04'):
            return _classify_zip_container(content)
        if content.startswith(b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'):
            # Old Office format
            return 'application/msword'
        if content.startswith(b'{\\'):
            # RTF
            return 'text/rtf'
    except Exception:
        # Not a bytes-like object (e.g. None or str): nothing to inspect.
        return None

    try:
        return magic.Magic(mime=True).from_buffer(content)
    except Exception:
        return None

def extract_text_from_binary(content, content_type):
    """Extract text from PDF, Word or Excel content.

    Returns:
        None when *content* is shorter than 250 bytes or *content_type* is
        not one of the handled binary types; otherwise the extracted text, or
        a bracketed ``[... Error ...]`` marker describing what failed.
    """
    if len(content) < 250:
        return None

    try:
        if content_type == 'application/pdf':
            text = parse_pdf(content)
            if not text or text.startswith('[PDF Content'):
                # Try alternate PDF parsing method
                with io.BytesIO(content) as pdf_file:
                    try:
                        with pdfplumber.open(pdf_file) as pdf:
                            # Extract each page once; the generator is fully
                            # consumed while the PDF is still open.
                            page_texts = (
                                page.extract_text(x_tolerance=3, y_tolerance=3)
                                for page in pdf.pages
                            )
                            text = '\n'.join(
                                page_text for page_text in page_texts if page_text
                            )
                        if text.strip():
                            return text
                    except Exception as e:
                        return f"[PDF Content - pdfplumber Error: {str(e)}]"
            return text
        elif content_type in ['application/msword', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document']:
            return extract_text_from_doc(content)
        elif content_type in ['application/vnd.ms-excel', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet']:
            return extract_text_from_excel(content)
        else:
            return None
    except Exception as e:
        return f"[Binary Content - Extraction Error: {str(e)}]"


def extract_text_from_doc(content):
    """Extract text from Word content with docx2txt; None when that fails.

    NOTE(review): the caller also routes 'application/msword' content here;
    presumably docx2txt only reads the zip-based .docx format, so legacy .doc
    files would end up as None - confirm.
    """
    try:
        return docx2txt.process(io.BytesIO(content))
    except Exception:
        # Best effort: an unreadable document is reported as "no text".
        return None

def extract_text_from_excel(content):
    """Extract cell text from a workbook, one line per row.

    Tries openpyxl first and falls back to xlrd; returns None when neither
    library can read *content*.
    """
    try:
        workbook = load_workbook(io.BytesIO(content))
        text = []
        for sheet in workbook.sheetnames:
            for row in workbook[sheet].iter_rows(values_only=True):
                # Skip only genuinely empty cells; 0 and False are real data.
                text.append(' '.join(
                    str(cell) for cell in row if cell is not None and cell != ''
                ))
        return '\n'.join(text)
    except Exception:
        try:
            workbook = xlrd.open_workbook(file_contents=content)
            text = []
            for sheet in workbook.sheets():
                for row in range(sheet.nrows):
                    text.append(' '.join(str(cell.value) for cell in sheet.row(row)))
            return '\n'.join(text)
        except Exception:
            return None
        

def parse_blob_content(content, provided_content_type=None):
    """Turn raw blob bytes into text.

    Strategy: use *provided_content_type* (or detect one), try structured
    extraction for that type, then try decoding with several candidate
    encodings, and finally try every binary extractor.

    Returns:
        A ``(text, content_type, encoding)`` tuple. ``text`` is a bracketed
        ``[...]`` marker for images, zip archives and undecodable data, and an
        error message (with the other two fields None) when parsing raises.
        ``(None, None, None)`` is returned for empty content.
    """
    try:
        if not content:
            return None, None, None

        # Enhanced content type detection
        content_type = provided_content_type or enhanced_content_type_detection(content)

        # Detection may fail and yield None; only inspect a real type string.
        if content_type and (content_type.startswith('image/') or content_type == 'application/zip'):
            return f"[{content_type} Content]", content_type, None

        # Try to extract text based on content type
        extracted_text = extract_text_from_binary(content, content_type)
        if extracted_text:
            return clean_text(extracted_text), content_type, 'utf-8'

        # If text extraction failed, proceed with encoding detection and decoding
        chardet_result = chardet.detect(content)
        charset_normalizer_result = detect(content)

        candidates = [
            chardet_result['encoding'],
            charset_normalizer_result.get('encoding'),
            'utf-8',
            'iso-8859-1',
            'windows-1252',
            'ascii'
        ]
        # Drop missing guesses and exact duplicates while keeping priority order.
        encodings = list(dict.fromkeys(name for name in candidates if name))

        best_decoded = None
        best_encoding = None
        max_printable_ratio = 0.6  # minimum ratio a decoding must beat

        for encoding in encodings:
            try:
                decoded = content.decode(encoding)
            except (UnicodeDecodeError, LookupError):
                # LookupError: a detector named an encoding Python lacks.
                continue
            printable_ratio = calculate_printable_ratio(decoded)
            if printable_ratio > max_printable_ratio:
                best_decoded = decoded
                best_encoding = encoding
                max_printable_ratio = printable_ratio
            if max_printable_ratio > 0.9:  # If we find a good enough encoding, stop searching
                break

        # If we still haven't found a good decoding, try binary file processing
        if max_printable_ratio <= 0.9 and len(content) > 250:
            binary_content_types = [
                'application/pdf',
                'application/msword',
                'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
                'application/vnd.ms-excel',
                'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                'application/rtf'
            ]

            for binary_type in binary_content_types:
                extracted_text = extract_text_from_binary(content, binary_type)
                if extracted_text:
                    printable_ratio = calculate_printable_ratio(extracted_text)
                    if printable_ratio > max_printable_ratio:
                        best_decoded = extracted_text
                        best_encoding = 'utf-8'  # Assume UTF-8 for extracted text
                        max_printable_ratio = printable_ratio
                        content_type = binary_type
                    if max_printable_ratio > 0.8:  # If we find a good enough extraction, stop searching
                        break

        if best_decoded is None:
            return f"[Binary data, unable to decode. Best printable ratio: {max_printable_ratio:.2f}]", content_type, None

        if content_type == "text/rtf":
            return rtf_to_text(best_decoded), content_type, best_encoding
        elif content_type in ["text/html", "text/xml", "application/xhtml+xml"]:
            soup = BeautifulSoup(best_decoded, 'html.parser')
            return soup.get_text(separator='\n', strip=True), content_type, best_encoding
        else:
            return clean_text(best_decoded), content_type, best_encoding

    except Exception as e:
        error_msg = f"Error in parse_blob_content: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)  # This will print to the Spark logs
        return error_msg, None, None

_PRINTABLE_CHARS = frozenset(string.printable)


def calculate_printable_ratio(text, sample_size=1000):
    """Return the fraction of characters of *text* found in ``string.printable``.

    Texts longer than *sample_size* are sampled at evenly spaced positions, so
    the same input always yields the same ratio. The previous random sample
    made results vary between runs.

    Returns 0.0 for empty input, for the pipeline's own
    "[Binary data, unable to decode." / "[PDF Content - Error extracting text]"
    markers, and when *sample_size* is not positive.
    """
    if not text:
        return 0.0
    if text.startswith("[Binary data, unable to decode."):
        return 0.0
    if text == "[PDF Content - Error extracting text]":
        return 0.0

    total_length = len(text)
    if total_length <= sample_size:
        # Short enough to inspect in full.
        sample = text
    else:
        # Deterministic sample spread evenly over the whole text.
        sample = [text[i * total_length // sample_size] for i in range(sample_size)]

    if not sample:
        return 0.0
    printable_count = sum(1 for char in sample if char in _PRINTABLE_CHARS)
    return printable_count / len(sample)

def parse_pdf(content):
    """
    Extract text from PDF bytes with pdfplumber, line by line.

    The bytes are first passed through repair_pdf; if the repaired file cannot
    be opened, extraction is retried on the untouched original bytes.

    Returns the extracted text, "[PDF Content - Empty Document]" when no text
    was found, or "[PDF Content - Error extracting text]" on failure (the
    processing log is printed in that case).
    """
    # Track processing steps for debugging; created before the try so the
    # error handler can always use it.
    processing_log = ["Starting PDF parsing"]
    try:
        # Keep the untouched bytes so the retry below really uses them.
        original_content = content

        # First try to repair the PDF if needed
        repaired_content = repair_pdf(content)
        if repaired_content:
            content = repaired_content
            processing_log.append("PDF repair successful")
        else:
            processing_log.append("PDF repair not needed or failed")

        with io.BytesIO(content) as pdf_file:
            try:
                with pdfplumber.open(pdf_file) as pdf:
                    extracted_text = []
                    processing_log.append(f"Successfully opened PDF with {len(pdf.pages)} pages")

                    for page_num, page in enumerate(pdf.pages, 1):
                        processing_log.append(f"Processing page {page_num}")

                        try:
                            # Extract text with position information
                            words = page.extract_words(
                                x_tolerance=3,
                                y_tolerance=3,
                                keep_blank_chars=False,
                                use_text_flow=False
                            )

                            if not words:
                                processing_log.append(f"No words found on page {page_num}")
                                continue

                            # Group consecutive words into lines: a new line
                            # starts when 'top' moves by more than 3 units.
                            # Words inside a line are ordered left to right.
                            lines = []
                            current_line_y = None
                            line_words = []

                            for word in words:
                                if current_line_y is None or abs(float(word['top']) - float(current_line_y)) > 3:
                                    if line_words:
                                        line_words.sort(key=lambda w: float(w['x0']))
                                        lines.append(line_words)
                                    line_words = [word]
                                    current_line_y = float(word['top'])
                                else:
                                    line_words.append(word)

                            if line_words:
                                line_words.sort(key=lambda w: float(w['x0']))
                                lines.append(line_words)

                            # Process each line
                            for line in lines:
                                line_text = ' '.join(word['text'] for word in line)
                                if line_text.strip():
                                    extracted_text.append(line_text)

                            # Add page break if not the last page
                            if page_num < len(pdf.pages):
                                extracted_text.append("\n" + "-"*70 + "\n")

                        except Exception as page_error:
                            processing_log.append(f"Error processing page {page_num}: {str(page_error)}")
                            continue

                    final_text = '\n'.join(extracted_text)
                    if final_text.strip():
                        processing_log.append("Successfully extracted text")
                        return final_text
                    else:
                        processing_log.append("No text content found in PDF")
                        return "[PDF Content - Empty Document]"

            except Exception as pdf_error:
                processing_log.append(f"Error opening PDF: {str(pdf_error)}")
                # If the repaired file failed, try the original bytes instead
                if repaired_content:
                    processing_log.append("Retrying with original content")
                    with pdfplumber.open(io.BytesIO(original_content)) as pdf:
                        # extract_text() may yield None for a page without text
                        text = '\n'.join(page.extract_text() or '' for page in pdf.pages)
                        if text.strip():
                            return text
                raise pdf_error

    except Exception as e:
        processing_log.append(f"Fatal error: {str(e)}")
        print("PDF Processing Log:")
        for log in processing_log:
            print(f"  - {log}")
        return "[PDF Content - Error extracting text]"
    
        
def repair_pdf(content):
    """
    Attempt to repair corrupted PDF files using pdftk.

    Writes *content* to a temporary file, runs ``pdftk <in> output <out>``
    with a 300-second limit and returns the rewritten bytes. Returns None when
    pdftk fails, times out, produces no output, or any other error occurs.
    Temporary files are always removed.
    """
    input_path = None
    output_path = None

    try:
        # Record each path as soon as its file exists, so the cleanup below
        # also covers a failure while creating the second file.
        with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as input_pdf:
            input_path = input_pdf.name
            input_pdf.write(content)
        with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as output_pdf:
            output_path = output_pdf.name

        try:
            # Run pdftk repair
            result = subprocess.run(
                ['pdftk', input_path, 'output', output_path],
                capture_output=True,
                timeout=300
            )
        except subprocess.TimeoutExpired:
            print("pdftk repair timed out after 300 seconds")
            return None
        except subprocess.SubprocessError as e:
            print(f"subprocess error running pdftk: {str(e)}")
            return None

        # Check process result
        if result.returncode != 0:
            print(f"pdftk repair failed with return code {result.returncode}")
            print(f"stderr: {result.stderr.decode('utf-8', errors='ignore')}")
            return None

        # Read repaired content
        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            with open(output_path, 'rb') as f:
                return f.read()

        print("Repair failed: Output file is empty or missing")
        return None

    except Exception as e:
        # Includes a missing pdftk binary and non-bytes content.
        print(f"Error in repair_pdf: {str(e)}")
        return None

    finally:
        # Cleanup temp files
        for path in [input_path, output_path]:
            if path and os.path.exists(path):
                try:
                    os.unlink(path)
                except Exception as e:
                    print(f"Error cleaning up temporary file {path}: {str(e)}")


def clean_text(text):
    """Strip ``<%...%>`` tags, turn pipes into newlines and squeeze blank lines.

    Tags are only matched within a single line. Newline runs are reduced in
    two passes; the second pass leaves at most one consecutive newline.
    """
    result = re.sub(r'<%.*?%>', '', text).replace('|', '\n')
    for pattern, replacement in ((r'\n{3,}', '\n\n'), (r'\n+', '\n')):
        result = re.sub(pattern, replacement, result)
    return result.strip()


def clean_non_printable(text):
    """Replace non-printable characters with spaces and normalise whitespace.

    Runs of two or more characters outside 0x20-0x7E first become a newline,
    remaining characters not in ``string.printable`` become a space, and all
    whitespace runs are finally collapsed to single spaces.
    """
    runs_collapsed = re.sub(r'[^\x20-\x7E]{2,}', '\n', text)
    allowed = string.printable
    pieces = []
    for char in runs_collapsed:
        pieces.append(char if char in allowed else ' ')
    # After the step above only ASCII whitespace is left, so splitting on
    # whitespace and re-joining collapses runs and trims both ends.
    return ' '.join(''.join(pieces).split())



In [0]:

# Columns of the rows returned by the pandas UDF, in output order.
# Every field is nullable.
_output_columns = [
    ("EVENT_ID", LongType),
    ("VALID_UNTIL_DT_TM", TimestampType),
    ("VALID_FROM_DT_TM", TimestampType),
    ("UPDT_DT_TM", TimestampType),
    ("UPDT_ID", LongType),
    ("UPDT_TASK", LongType),
    ("UPDT_CNT", LongType),
    ("UPDT_APPLCTX", LongType),
    ("LAST_UTC_TS", TimestampType),
    ("ADC_UPDT", TimestampType),
    ("BLOB_BINARY", BinaryType),
    ("CONTENT_TYPE", StringType),
    ("ENCODING", StringType),
    ("BLOB_TEXT", StringType),
    ("BINARY_SIZE", LongType),
    ("TEXT_LENGTH", LongType),
    ("STATUS", StringType),
]

output_schema = StructType(
    [StructField(column_name, column_type(), True) for column_name, column_type in _output_columns]
)


# Global set to store failed EVENT_IDs
failed_event_ids = set()




def remove_ocf_wrapper(blob_contents):
    """Strip every ``ocf_blob\\0`` marker from the content.

    A trailing marker is cut off first, then any remaining occurrences are
    removed. Returns None when *blob_contents* cannot be processed.
    """
    marker = b'ocf_blob\0'
    try:
        trimmed = blob_contents
        if trimmed.endswith(marker):
            trimmed = trimmed[:-len(marker)]

        if marker not in trimmed:
            return trimmed

        # Splitting on the marker and re-joining drops all occurrences.
        return b''.join(trimmed.split(marker))
    except Exception:
        return None
    
def _quick_rtf_to_text(rtf_source, max_chars=1000000):
    """Cheap single-pass RTF stripper: skips control words and braces.

    Only the first *max_chars* characters are scanned. Whitespace between
    words is kept as a single space, so words no longer run together.
    """
    pieces = []
    in_control_word = False
    brace_depth = 0
    for char in rtf_source[:max_chars]:
        if char == '\\':
            in_control_word = True
        elif in_control_word:
            # A control word lasts until the next whitespace character.
            if char.isspace():
                in_control_word = False
                pieces.append(' ')
        elif char == '{':
            brace_depth += 1
        elif char == '}':
            brace_depth -= 1
        elif char.isspace():
            # Preserve word boundaries without piling up spaces.
            if pieces and pieces[-1] != ' ':
                pieces.append(' ')
        elif brace_depth >= 0:
            pieces.append(char)
    return ''.join(pieces).strip()


def process_single_row(row, row_timeout=120):
    """Decode one grouped blob row into a result dict (see create_result_dict).

    Steps: reject blobs over 8 MB, combine and decompress the chunks, strip
    the OCF wrapper, detect the content type, then extract text. RTF gets a
    dedicated fast path under a 60-second timeout; everything else goes
    through parse_blob_content. Any exception becomes an ``Error: ...`` status
    instead of propagating.

    Args:
        row: mapping with the grouped blob columns.
        row_timeout: currently unused; kept so existing callers keep working.
    """
    try:
        total_blob_length = row['TOTAL_BLOB_LENGTH']
        if total_blob_length > 8 * 1024 * 1024:  # 8 MB
            status = f"File Too Large: {format_size(total_blob_length)}"
            return create_result_dict(row, status=status)

        # Combine chunks
        blob_contents = combine_blob_chunks(row['BLOB_CONTENTS_LIST'])

        # Decompress if needed; a str result is an error/unknown-type message
        decompressed = decompress_blob(blob_contents, row['COMPRESSION_CD'])
        if isinstance(decompressed, str):
            return create_result_dict(row, status=decompressed)

        if decompressed is None:
            status = "Decompression returned None"
            return create_result_dict(row, None, None, None, None, status)

        # Remove OCF wrapper if present
        cleaned_content = remove_ocf_wrapper(decompressed)
        if cleaned_content is not None:
            decompressed = cleaned_content

        content_type = enhanced_content_type_detection(decompressed)

        # For RTF files, use optimized handling with timeout protection
        if content_type == "text/rtf":
            try:
                with timeout(60):  # 60-second timeout for RTF processing
                    content_str = decompressed.decode('latin-1', errors='ignore')

                    try:
                        # Try simplified extraction first (much faster)
                        blob_text = _quick_rtf_to_text(content_str)
                        if len(blob_text) < 100:  # Very little text: use the standard method
                            blob_text = rtf_to_text(content_str)
                    except ThreadTimeoutError:
                        # A timeout must reach the handler below rather than
                        # restarting the slow parser.
                        raise
                    except Exception:
                        # Fall back to standard method
                        blob_text = rtf_to_text(content_str)

                    return create_result_dict(row, decompressed, content_type, 'latin-1',
                                              blob_text, status='Decoded')
            except ThreadTimeoutError:
                # If RTF processing times out, use a very simple extraction
                try:
                    # Emergency fallback - just grab anything that looks like text
                    content_str = decompressed.decode('latin-1', errors='ignore')
                    text = re.sub(r'[^\x20-\x7E\n]', ' ', content_str)
                    text = re.sub(r'\s+', ' ', text).strip()
                    return create_result_dict(row, decompressed, content_type, 'latin-1',
                                              text, status='Decoded (fallback method)')
                except Exception:
                    return create_result_dict(row, decompressed, content_type, None,
                                              "[RTF processing timed out]", status='Timeout')

        # Handle other content types
        blob_text, detected_type, encoding = parse_blob_content(decompressed, content_type)

        # Use detected content type if available
        if detected_type:
            content_type = detected_type

        if blob_text:
            # Bracketed text is a marker describing why no real text exists.
            if isinstance(blob_text, str) and blob_text.startswith("["):
                status = blob_text
            else:
                status = 'Decoded'
        else:
            status = 'Failed to decode'

        return create_result_dict(row, decompressed, content_type, encoding, blob_text, status)

    except Exception as e:
        return create_result_dict(row, status=f"Error: {str(e)}")
    
    
def create_result_dict(row, decompressed=None, content_type=None, encoding=None, blob_text=None, status="Error"):
    """Build one output record, keyed like ``output_schema``.

    Args:
        row: mapping providing the pass-through columns (EVENT_ID, UPDT_*, ...).
        decompressed: decoded binary payload, or None.
        content_type: detected MIME type, or None.
        encoding: encoding used to decode the text, or None.
        blob_text: extracted text; non-strings are converted with ``str``.
        status: processing status; coerced to ``str``.

    BINARY_SIZE is ``len(decompressed)`` (-1 if it has no usable length).
    TEXT_LENGTH is measured on the sanitised text actually stored in
    BLOB_TEXT, so the two always agree.
    """
    # Ensure status is always a string
    if not isinstance(status, str):
        status = str(status)

    def safe_encode(text):
        """Return *text* as a str that can be encoded as UTF-8, or None."""
        if isinstance(text, str):
            # Drops characters that cannot be encoded (e.g. lone surrogates).
            return text.encode('utf-8', errors='ignore').decode('utf-8')
        if text is not None:
            try:
                return str(text)
            except Exception:
                return "[Non-string content]"
        return text

    binary_size = None
    if decompressed is not None:
        try:
            binary_size = len(decompressed)
        except Exception:
            binary_size = -1  # Use a sentinel value for error

    stored_text = safe_encode(blob_text)
    text_length = len(stored_text) if stored_text is not None else None

    return {
        "EVENT_ID": row['EVENT_ID'],
        "VALID_UNTIL_DT_TM": row['VALID_UNTIL_DT_TM'],
        "VALID_FROM_DT_TM": row['VALID_FROM_DT_TM'],
        "UPDT_DT_TM": row['UPDT_DT_TM'],
        "UPDT_ID": row['UPDT_ID'],
        "UPDT_TASK": row['UPDT_TASK'],
        "UPDT_CNT": row['UPDT_CNT'],
        "UPDT_APPLCTX": row['UPDT_APPLCTX'],
        "LAST_UTC_TS": row['LAST_UTC_TS'],
        "ADC_UPDT": row['ADC_UPDT'],
        "BLOB_BINARY": decompressed,
        "CONTENT_TYPE": content_type,
        "ENCODING": encoding,
        "BLOB_TEXT": stored_text,
        "BINARY_SIZE": binary_size,
        "TEXT_LENGTH": text_length,
        "STATUS": status
    }


# Define a pandas UDF to process rows
# Grouped-map pandas UDF: decodes every row of the incoming pandas frame.
@pandas_udf(output_schema, PandasUDFType.GROUPED_MAP)
def process_rows(pdf):
    """Run process_single_row on each row and return the results as a frame.

    Rows whose status starts with ``Error:`` or ``[`` - and rows that raise -
    have their EVENT_ID added to ``failed_event_ids``.

    NOTE(review): this function presumably executes in Spark Python worker
    processes, so additions to ``failed_event_ids`` made here are likely not
    visible to driver-side code - confirm before relying on that set.
    """
    global failed_event_ids
    collected = []

    for _, record in pdf.iterrows():
        try:
            outcome = process_single_row(record)
            outcome_status = outcome['STATUS']
            looks_failed = isinstance(outcome_status, str) and (
                outcome_status.startswith('Error:') or outcome_status.startswith('[')
            )
            if looks_failed:
                failed_event_ids.add(record['EVENT_ID'])
        except Exception as exc:
            failed_event_ids.add(record['EVENT_ID'])
            outcome = create_result_dict(record, status=f"Error: Unexpected - {str(exc)}")
        collected.append(outcome)

    return pd.DataFrame(collected)






In [0]:


class ThreadTimeoutError(Exception):
    """Raised inside a ``timeout`` block when its time budget is used up."""


@contextmanager
def timeout(seconds):
    """Raise ThreadTimeoutError in the calling thread after *seconds*.

    A timer thread asks the interpreter to raise the exception asynchronously
    in the thread that entered the ``with`` block. Raising from inside the
    timer callback, as before, only terminated the timer thread and never
    interrupted the guarded code.

    Limitations:
      * CPython only (uses ``PyThreadState_SetAsyncExc``).
      * The exception is delivered when the target thread next executes
        Python bytecode; a long blocking call into native code is only
        interrupted once it returns.
      * The exception is raised from the class, so it carries no message.
    """
    import ctypes  # stdlib; imported here so this block is self-contained

    target_thread_id = threading.get_ident()
    state_lock = threading.Lock()
    finished = False

    def raise_timeout():
        # The lock guarantees nothing is injected once the block has exited.
        with state_lock:
            if finished:
                return
            ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_ulong(target_thread_id),
                ctypes.py_object(ThreadTimeoutError),
            )

    timer = threading.Timer(seconds, raise_timeout)
    timer.daemon = True  # never keep the process alive just for the timer
    timer.start()
    try:
        yield
    finally:
        with state_lock:
            finished = True
        timer.cancel()

def process_and_write_data(data, mode="append", retry_count=0, batch_timeout=3600):  # 60 minutes timeout
    """Decode a batch of grouped blob rows and insert them into the target table.

    Args:
        data: DataFrame of grouped blob rows with an EVENT_ID column.
        mode: write mode passed to the DataFrame writer.
        retry_count: number of retries already performed (at most 3 are made).
        batch_timeout: time budget in seconds for processing plus writing.

    Returns:
        True when the batch was written or nothing was left to process,
        False on timeout or when retries are exhausted.
    """
    import gc

    gc.collect()
    global failed_event_ids
    try:
        # Remove previously failed EVENT_IDs
        if failed_event_ids:
            data = data.filter(~F.col("EVENT_ID").isin(list(failed_event_ids)))

        if data.count() == 0:
            print("No data to process after removing failed EVENT_IDs.")
            return True

        with timeout(batch_timeout):
            # Process the data
            processed_df = data.groupby("EVENT_ID").apply(process_rows)

            # Write the processed data
            processed_df.write.mode(mode).insertInto(target_table_name)
        gc.collect()
        return True
    except ThreadTimeoutError:
        print(f"Batch processing timed out after {batch_timeout} seconds")
        # Here you might want to add logic to handle the timeout,
        # such as splitting the batch or marking it for later processing
        return False
    except Exception as e:
        print(f"Error in processing: {str(e)}")
        if retry_count < 3 and len(failed_event_ids) > 0:
            print(f"Retrying without {len(failed_event_ids)} failed EVENT_IDs. Retry count: {retry_count + 1}")
            # Carry the caller's write mode and timeout into the retry instead
            # of silently falling back to the defaults.
            return process_and_write_data(
                data,
                mode=mode,
                retry_count=retry_count + 1,
                batch_timeout=batch_timeout,
            )
        else:
            print("Max retries reached. Some data could not be processed.")
            return False




    

In [0]:
# Target table and the newest ADC_UPDT already stored in it
# (get_max_adc_updt falls back to 1980-01-01).
target_table_name = "4_prod.bronze.mill_blob_text"
max_adc_updt = get_max_adc_updt(target_table_name)

# Batch sizes (bytes) with their timeouts (seconds), tried from largest to smallest
batch_configs = [
    {"size": size_bytes, "timeout": timeout_seconds}
    for size_bytes, timeout_seconds in (
        (2048 * 1024 * 1024, 3600),  # 2GB, 60 minutes
        (64 * 1024 * 1024, 600),     # 64 MB, 10 minutes
        (1 * 1024 * 1024, 300),      # 1 MB, 5 minutes
        (1, 120),                    # Single record, 2 minutes
    )
]

# Source rows newer than the latest ADC_UPDT in the target table
print(f"\nProcessing data newer than ADC_UPDT: {max_adc_updt}")
df = spark.table("4_prod.raw.mill_ce_blob").filter(
    F.col("ADC_UPDT") > F.lit(max_adc_updt)
)

# If no new data, exit
if df.count() == 0:
    print("No new data to process")
    exit(0)

# One row per (EVENT_ID, COMPRESSION_CD): chunks ordered by BLOB_SEQ_NUM,
# the summed length, the chunk count and the first value of each
# pass-through column.
_ordered_chunks = F.sort_array(
    F.collect_list(F.struct("BLOB_SEQ_NUM", "BLOB_CONTENTS"))
)
_passthrough_columns = [
    "VALID_UNTIL_DT_TM",
    "VALID_FROM_DT_TM",
    "UPDT_DT_TM",
    "UPDT_ID",
    "UPDT_TASK",
    "UPDT_CNT",
    "UPDT_APPLCTX",
    "LAST_UTC_TS",
    "ADC_UPDT",
]
grouped_df = df.groupBy("EVENT_ID", "COMPRESSION_CD").agg(
    _ordered_chunks.getField("BLOB_CONTENTS").alias("BLOB_CONTENTS_LIST"),
    F.sum(F.col("BLOB_LENGTH").cast("bigint")).alias("TOTAL_BLOB_LENGTH"),
    F.count("*").alias("ROW_COUNT"),
    *[F.first(column_name).alias(column_name) for column_name in _passthrough_columns],
)

# Size information for every EVENT_ID to process, ordered by ADC_UPDT
_size_columns = grouped_df.select("EVENT_ID", "TOTAL_BLOB_LENGTH", "ROW_COUNT", "ADC_UPDT")
size_df = _size_columns.orderBy("ADC_UPDT", "EVENT_ID").collect()

remaining_events = set(row['EVENT_ID'] for row in size_df)

# Process the data in batches. Each configuration is applied to whatever
# events are still unprocessed, so batches that fail at one size are retried
# with the next, smaller size.
LARGE_BLOB_BYTES = 8 * 1024 * 1024   # blobs above this are not decoded
MAX_EVENTS_PER_BATCH = 100000


def _process_event_batch(event_ids, size_bytes, row_count, label, timeout_seconds):
    """Process one batch of EVENT_IDs; return True when it was written."""
    print(f"Processing {label} with {len(event_ids)} EVENT_IDs")
    print(f"Batch size: {size_bytes/1048576:.2f} MB, {row_count} rows")

    started = time.time()
    batch_df = grouped_df.filter(F.col("EVENT_ID").isin(event_ids))
    succeeded = process_and_write_data(batch_df, mode="append", batch_timeout=timeout_seconds)
    print(f"Batch duration: {time.time() - started:.2f} seconds")
    return succeeded


for batch_config in batch_configs:
    target_batch_size = batch_config["size"]
    batch_timeout = batch_config["timeout"]

    if not remaining_events:
        break

    print(f"\nProcessing with batch size: {target_batch_size/1024/1024:.2f} MB (timeout: {batch_timeout} seconds)")

    # Events still waiting, in the original ADC_UPDT order
    current_size_df = [row for row in size_df if row['EVENT_ID'] in remaining_events]

    current_batch = []
    current_batch_size = 0
    current_batch_row_count = 0
    batch_number = 1

    for row in current_size_df:
        event_id = row['EVENT_ID']
        blob_length = row['TOTAL_BLOB_LENGTH'] or 0
        row_count = row['ROW_COUNT']

        # Oversized blobs are skipped by process_single_row, so they
        # count as a nominal 1 byte towards the batch size.
        if blob_length > LARGE_BLOB_BYTES:
            effective_blob_length = 1
        else:
            effective_blob_length = blob_length

        # Flush BEFORE adding the event if it would push the batch over the
        # limit. The size is counted once, and the 1-byte configuration
        # yields true single-event batches.
        would_overflow = current_batch_size + effective_blob_length > target_batch_size
        if current_batch and (would_overflow or len(current_batch) >= MAX_EVENTS_PER_BATCH):
            if _process_event_batch(current_batch, current_batch_size, current_batch_row_count,
                                    f"batch {batch_number}", batch_timeout):
                remaining_events -= set(current_batch)
            else:
                print(f"Batch {batch_number} failed or timed out. Will try with smaller batch size.")

            batch_number += 1
            current_batch = []
            current_batch_size = 0
            current_batch_row_count = 0

        current_batch.append(event_id)
        current_batch_size += effective_blob_length
        current_batch_row_count += row_count

    # Process final batch for current batch size if needed
    if current_batch:
        if _process_event_batch(current_batch, current_batch_size, current_batch_row_count,
                                f"final batch {batch_number}", batch_timeout):
            remaining_events -= set(current_batch)
        else:
            print(f"Batch {batch_number} failed or timed out. Will try with smaller batch size.")

if remaining_events:
    print(f"\n{len(remaining_events)} EVENT_IDs could not be processed with any batch size.")

print("\nProcessing complete.")