# Protocol Document Processor (Spark Parallel)

Parses Protocol documents using PDF Table of Contents (TOC) for section extraction, with **Spark distributed parallelism** for high performance.

## Purpose
- Process Protocol documents uploaded via the UI
- Use PDF bookmarks/TOC to identify section boundaries (page ranges)
- Extract ordered, structured content (text and tables) from every TOC section, or from the subset selected by the configured section filter
- Store the results in `silver_md.md_protocol_document_sections` for UI preview (Schedule of Activities focus)

## Key Features
- **Spark Parallel Processing**: Uses `mapInPandas` for distributed section extraction
- **TOC-Based Section Detection**: Uses PyMuPDF to read PDF bookmarks/outlines
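- **Volume Temp Files**: Writes each section's pages as a temporary PDF under a per-run `temp_sections/<run_id>` folder next to the source documents on the Volume, where both the workers and `ai_parse_document` can read them; the folder is removed at the end of the run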
- **Batch AI Processing**: Parallel `ai_parse_document` calls via a Spark SQL expression over the extracted section files
- **Ordered Content Storage**: Preserves text/table sequence within sections

## Performance
- Processes multiple sections in parallel across cluster workers
- Estimated 5-10x speedup compared to sequential processing

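## Parameters
- `catalog_override`: Unity Catalog name
- `sandbox`: If 'true', only extract and store for preview (no downstream processing); the current notebook only reports this flag in its exit summary
- `study_id`: Optional study ID stamped on every extracted content row
- `protocol_id`: Optional protocol ID stamped on every extracted content row; when set, the matching `md_protocol_draft` row's status is updated at the end of the run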


In [None]:
# Cell 0: Install required packages for PDF TOC extraction
# PyMuPDF (imported as `fitz`) is used for the no-TOC page-count fallback in
# STEP 1, and PyPDF2 for splitting section pages in Phase 1. pdfminer.six is
# not imported directly in this notebook - presumably needed by the framework
# TOC module (TODO confirm).
%pip install PyMuPDF pdfminer.six PyPDF2 --quiet
# Note: Restart Python kernel if packages were just installed
# dbutils.library.restartPython()


In [None]:
# Cell 1: Setup and Parameters
import json
import re
import os
import uuid
from datetime import datetime
from typing import List, Dict, Iterator

# PySpark
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lit, expr, current_timestamp
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

# Widget parameters, registered in declaration order as (name, default value).
for _widget_name, _widget_default in (
    ("catalog_override", "aira_test"),
    ("sandbox", "true"),
    ("study_id", ""),
    ("protocol_id", ""),
):
    dbutils.widgets.text(_widget_name, _widget_default)

catalog = dbutils.widgets.get("catalog_override")
# Only the text "true" (any letter case) turns sandbox mode on.
sandbox_mode = dbutils.widgets.get("sandbox").lower() == "true"
# Blank IDs are normalised to None so the content rows written later carry NULL.
study_id = dbutils.widgets.get("study_id") or None
protocol_id = dbutils.widgets.get("protocol_id") or None

for _label, _value in (("Study ID", study_id), ("Protocol ID", protocol_id)):
    print(f"{_label}: {_value}")

# Schema names
bronze_schema, silver_schema, gold_schema = "bronze_md", "silver_md", "gold_md"

# Job lineage values are published by the upstream "setup" task. If any of them
# cannot be read (e.g. the notebook is run standalone for testing), all four
# fall back to local test values.
_TRACKING_KEYS = ("databricks_run_id", "databricks_job_id", "databricks_job_name", "created_by_principal")
try:
    _tracking = {
        tracking_key: dbutils.jobs.taskValues.get(taskKey="setup", key=tracking_key)
        for tracking_key in _TRACKING_KEYS
    }
except Exception:
    _tracking = {
        "databricks_run_id": "test_run_" + datetime.now().strftime("%Y%m%d_%H%M%S"),
        "databricks_job_id": "test_job",
        "databricks_job_name": "nb_protocol_processor",
        "created_by_principal": "test_user",
    }
databricks_run_id = _tracking["databricks_run_id"]
databricks_job_id = _tracking["databricks_job_id"]
databricks_job_name = _tracking["databricks_job_name"]
created_by_principal = _tracking["created_by_principal"]

print(f"Catalog: {catalog}")
print(f"Sandbox Mode: {sandbox_mode}")
print(f"Run ID: {databricks_run_id}")
print(f"Job ID: {databricks_job_id}")


In [None]:
# Cell 2: Define protocol document sections table schema with ordered content support
# Table: silver_md.md_protocol_document_sections
# Added study_id and protocol_id for linking to study management

# Fully-qualified names of the output table and the file-history table.
SECTIONS_TABLE = f"{catalog}.{silver_schema}.md_protocol_document_sections"
FILE_HISTORY_TABLE = f"{catalog}.{bronze_schema}.md_file_history"

# Column layout of md_protocol_document_sections: one row per content item
# (text or table) inside a TOC section, kept in sequence via content_order.
# Every column is declared nullable so rows with missing values can be built
# against this schema without NullPointerException.
_SECTIONS_TABLE_COLUMNS = (
    ("section_id", StringType()),            # unique ID for the section
    ("document_id", StringType()),           # FK to md_file_history
    ("study_id", StringType()),              # FK to md_study
    ("protocol_id", StringType()),           # FK to md_protocol
    ("section_name", StringType()),          # TOC title, e.g. "Schedule of Activities"
    ("section_level", IntegerType()),        # TOC hierarchy level
    ("page_start", IntegerType()),           # first page of the section (0-based)
    ("page_end", IntegerType()),             # last page of the section (0-based)
    ("content_order", IntegerType()),        # 1, 2, 3... within the section
    ("content_type", StringType()),          # "table" or "text"
    ("content_data", StringType()),          # JSON payload (table headers/rows, or text)
    ("created_ts", TimestampType()),         # when extraction was performed
    ("created_by_principal", StringType()),
    ("databricks_run_id", StringType()),
    ("databricks_job_id", StringType()),
    ("databricks_job_name", StringType()),
)
sections_schema = StructType(
    [StructField(column_name, column_type, True) for column_name, column_type in _SECTIONS_TABLE_COLUMNS]
)

# Development convenience: a table that predates the section_id column is
# dropped so the CREATE below rebuilds it with the current layout. Any error
# here (most commonly "table does not exist") is deliberately ignored.
# In production, evolve the table with ALTER TABLE instead.
try:
    if "section_id" not in {field.name for field in spark.table(SECTIONS_TABLE).schema}:
        print("‚ö†Ô∏è Dropping old table to recreate with new schema...")
        spark.sql(f"DROP TABLE IF EXISTS {SECTIONS_TABLE}")
except Exception:
    pass

# Idempotent create; a failure is reported but does not stop the notebook.
try:
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {SECTIONS_TABLE} (
            section_id STRING COMMENT 'Unique ID (UUID) for the section',
            document_id STRING COMMENT 'FK to md_file_history.document_id',
            study_id STRING COMMENT 'FK to md_study.study_id',
            protocol_id STRING COMMENT 'FK to md_protocol.protocol_id',
            section_name STRING COMMENT 'Section name from TOC (e.g., Schedule of Activities)',
            section_level INT COMMENT 'TOC hierarchy level (1=top level)',
            page_start INT COMMENT 'Section start page (0-based)',
            page_end INT COMMENT 'Section end page (0-based)',
            content_order INT COMMENT 'Sequence within section (1, 2, 3...) to preserve order',
            content_type STRING COMMENT 'Content type: table or text',
            content_data STRING COMMENT 'JSON content - for table: headers/rows, for text: content string',
            created_ts TIMESTAMP COMMENT 'When extraction was performed',
            created_by_principal STRING COMMENT 'User or service principal',
            databricks_run_id STRING COMMENT 'Job run ID for lineage',
            databricks_job_id STRING COMMENT 'Job ID for lineage',
            databricks_job_name STRING COMMENT 'Job name for lineage'
        )
        USING DELTA
        COMMENT 'Protocol document details storing ordered content from extracted sections'
        TBLPROPERTIES (
            'delta.enableChangeDataFeed' = 'false'
        )
    """)
    print(f"‚úÖ Table {SECTIONS_TABLE} ready")
except Exception as e:
    print(f"‚ö†Ô∏è Table creation: {e}")


In [None]:
# Cell 3: TOC extraction via the shared framework module.
# PDFTOCExtractor (clinical_data_standards_framework.toc_utils) reads the PDF
# outline and yields Section records with page ranges for STEP 1.

# PyMuPDF is imported directly only for the page-count fallback used when a
# document has no TOC.
import fitz  # PyMuPDF

from clinical_data_standards_framework.toc_utils import PDFTOCExtractor, Section

_toc_ready_message = "‚úÖ TOC extraction loaded from clinical_data_standards_framework.toc_utils"

# One extractor instance is shared by every document processed in STEP 1.
pdf_toc_extractor = PDFTOCExtractor()
print(_toc_ready_message)


In [None]:
# Cell 4: (intentionally empty) Legacy helper functions were removed; TOC
# extraction now uses the framework module (Cell 3) and the parallel phases below.


In [None]:
# Cell 5: Get pending Protocol documents and build sections DataFrame

# ============================================================================
# Load section filter configuration from task values (set by setup task)
# ============================================================================
# On any failure the notebook falls back to "no filtering" (every TOC section
# is processed), see the except branch.
try:
    pipeline_config_json = dbutils.jobs.taskValues.get(taskKey="setup", key="pipeline_config")
    # Task values may arrive as a JSON string or as an already-decoded object;
    # calling json.loads on a dict would raise and silently disable the config.
    if isinstance(pipeline_config_json, str):
        pipeline_config = json.loads(pipeline_config_json) if pipeline_config_json else {}
    else:
        pipeline_config = pipeline_config_json or {}

    ai_processing_config = pipeline_config.get('ai_processing', {})
    section_filter_config = ai_processing_config.get('section_filter', {})
    target_sections = ai_processing_config.get('target_sections', [])

    # Section filter settings
    filter_enabled = section_filter_config.get('enabled', False)
    max_level = section_filter_config.get('max_level', None)
    # STEP 1 compares max_level numerically with TOC levels, so normalise
    # values such as "2" coming from JSON/YAML config.
    if max_level is not None:
        max_level = int(max_level)
    process_all_if_no_match = section_filter_config.get('process_all_if_no_match', False)

    # Collect regex patterns from target sections. Each one is validated here so
    # that a bad pattern is reported once instead of raising re.error for (and
    # thereby failing) every document in STEP 1.
    section_patterns = []
    for ts in target_sections:
        pattern = ts.get('pattern') if isinstance(ts, dict) else None
        if not pattern:
            continue
        try:
            re.compile(pattern)
        except re.error as pattern_error:
            print(f"WARNING: ignoring invalid section pattern {pattern!r}: {pattern_error}")
            continue
        section_patterns.append(pattern)

    print(f"üìã Section Filter Config (from task values):")
    print(f"   Enabled: {filter_enabled}")
    print(f"   Max Level: {max_level or 'All'}")
    print(f"   Patterns: {section_patterns}")

except Exception as e:
    print(f"‚ö†Ô∏è Could not load config from task values, processing all sections: {e}")
    filter_enabled = False
    max_level = None
    section_patterns = []
    process_all_if_no_match = True

def section_matches_filter(section_title: str, patterns: List[str]) -> bool:
    """Return True when *section_title* matches at least one regex in *patterns*.

    Matching is case-insensitive and uses ``re.search``, so a pattern may match
    anywhere in the title. An empty (or falsy) pattern list matches every title.
    """
    return not patterns or any(
        re.search(candidate, section_title, re.IGNORECASE) for candidate in patterns
    )

# Protocol-tagged rows of md_file_history that are READY_FOR_PROCESSING with a
# .pdf/.docx/.doc extension, newest first; file_name is the last path segment.
# NOTE(review): STEP 1 opens every file with the PDF TOC extractor / PyMuPDF -
# confirm .docx/.doc inputs are actually supported there.
pending_docs_query = f"""
    SELECT 
        document_id,
        extracted_path,
        SPLIT(extracted_path, '/')[SIZE(SPLIT(extracted_path, '/')) - 1] as file_name,
        file_extension,
        document_tags,
        status,
        databricks_run_id as source_run_id
    FROM {FILE_HISTORY_TABLE}
    WHERE array_contains(document_tags, 'Protocol')
      AND status = 'READY_FOR_PROCESSING'
      AND file_extension IN ('.pdf', '.docx', '.doc')
    ORDER BY created_ts DESC
"""

pending_docs_df = spark.sql(pending_docs_query)
pending_count = pending_docs_df.count()

print(f"\nFound {pending_count} Protocol documents to process")

if pending_count == 0:
    # Nothing to do: finish the run early with a zero-document success payload.
    print("No Protocol documents to process. Exiting.")
    dbutils.notebook.exit(json.dumps({"status": "success", "documents_processed": 0}))
else:
    pending_docs_df.show(truncate=False)

# ============================================================================
# STEP 1: Extract TOC sections from all documents (on driver)
# This creates a DataFrame with one row per section
# Applies section filtering based on config (patterns and max_level)
# ============================================================================
print("\n" + "="*60)
print("STEP 1: Extracting TOC sections from all documents...")
if filter_enabled:
    print(f"   (Filtering to: {section_patterns})")
print("="*60)

# page_end recorded when the TOC gives no end page; Phase 1 clamps it to the
# document's last page.
UNKNOWN_PAGE_END = 999

sections_data = []        # one dict per section to extract (input to Phase 1)
processed_doc_ids = []    # documents whose TOC was read successfully
failed_doc_ids = []       # documents whose TOC extraction raised
total_sections_found = 0
sections_after_filter = 0

for doc in pending_docs_df.collect():
    document_id = doc['document_id']
    extracted_path = doc['extracted_path']
    file_name = doc['file_name']

    print(f"\nüìÑ {file_name}")

    try:
        # Extract all sections from TOC using the framework extractor
        sections = pdf_toc_extractor.extract_sections(extracted_path)

        if not sections:
            print(f"  ‚ö†Ô∏è No TOC found, treating as single section")
            # Without an outline the whole document becomes one section.
            with fitz.open(extracted_path) as doc_pdf:
                total_pages = doc_pdf.page_count
            sections = [Section(
                title="Full Document",
                level=1,
                page_start=0,
                page_end=total_pages - 1
            )]

        total_sections_found += len(sections)
        print(f"  üìã Found {len(sections)} sections in TOC")

        # Keep sections within max_level (when configured) and, when the pattern
        # filter is enabled with patterns, whose title matches one of them.
        matched_sections = []
        for section in sections:
            if max_level and section.level > max_level:
                continue
            if filter_enabled and section_patterns and not section_matches_filter(section.title, section_patterns):
                continue
            matched_sections.append(section)

        if filter_enabled:
            print(f"  ‚úÖ After filter: {len(matched_sections)} sections match patterns")

        # If no matches and process_all_if_no_match is False, skip this document.
        # NOTE(review): a skipped document is added to neither processed_doc_ids
        # nor failed_doc_ids, so its md_file_history status stays
        # READY_FOR_PROCESSING and it is picked up again on every run - confirm.
        if filter_enabled and not matched_sections and not process_all_if_no_match:
            print(f"  ‚ö†Ô∏è No sections match filter patterns, skipping document")
            continue

        # Use matched sections (or all if filter not enabled/no matches)
        sections_to_process = matched_sections if matched_sections else (sections if process_all_if_no_match else [])

        for section in sections_to_process:
            sections_data.append({
                'section_id': str(uuid.uuid4()),
                'document_id': document_id,
                'source_path': extracted_path,
                'section_title': section.title,
                'section_level': int(section.level),
                'page_start': int(section.page_start),
                # Test for None, not truthiness: a section ending on page 0 (the
                # first page) must keep page_end=0 instead of being widened to
                # the whole document.
                'page_end': int(section.page_end) if section.page_end is not None else UNKNOWN_PAGE_END,
            })

        sections_after_filter += len(sections_to_process)
        processed_doc_ids.append(document_id)

    except Exception as e:
        print(f"  ‚ùå Error extracting TOC: {e}")
        failed_doc_ids.append(document_id)

print(f"\n‚úÖ Sections: {total_sections_found} found ‚Üí {sections_after_filter} after filter")
print(f"‚úÖ Created {len(sections_data)} section records from {len(processed_doc_ids)} documents")

# Create sections DataFrame: one work item per section for Phase 1.
# This schema gets its own name so it does not overwrite `sections_schema`
# (the md_protocol_document_sections table layout declared in Cell 2).
toc_sections_schema = StructType([
    StructField("section_id", StringType(), False),
    StructField("document_id", StringType(), False),
    StructField("source_path", StringType(), False),
    StructField("section_title", StringType(), True),
    StructField("section_level", IntegerType(), True),
    StructField("page_start", IntegerType(), True),
    StructField("page_end", IntegerType(), True)
])

sections_df = spark.createDataFrame(sections_data, schema=toc_sections_schema)

# Repartition for parallel processing: about SECTIONS_PER_PARTITION sections
# per partition, rounded up (so 11 sections use 2 partitions, not 1), at least
# 1 and at most MAX_SECTION_PARTITIONS.
SECTIONS_PER_PARTITION = 6
MAX_SECTION_PARTITIONS = 16
num_sections = len(sections_data)
num_partitions = max(1, -(-num_sections // SECTIONS_PER_PARTITION))
num_partitions = min(num_partitions, MAX_SECTION_PARTITIONS)
sections_df = sections_df.repartition(num_partitions)

# The row count is already known on the driver; no Spark job needed to count it.
print(f"üìä Sections DataFrame: {num_sections} rows, {num_partitions} partitions")


In [None]:
# Cell 6: PHASE 1 - Extract PDF sections to temp Volume files (parallel, no Spark SQL)

# ============================================================================
# TWO-PHASE APPROACH for Serverless Compute:
# Phase 1: mapInPandas extracts PDF sections to temp files (parallel, no SQL)
# Phase 2: Spark SQL calls ai_parse_document on all files (parallel via SQL)
# ============================================================================

# Temp section PDFs go to <base>/temp_sections/<run_id>, where <base> is a
# source document path with its last two segments removed; the run ID gives
# each run its own folder, reachable by both the Phase 1 workers and the
# Phase 2 binaryFile reader.
# Derive <base> from a path STEP 1 already collected: re-querying the file
# history table could return a different row, or none, if statuses changed.
if sections_data:
    reference_source_path = sections_data[0]['source_path']
else:
    first_doc = pending_docs_df.first()
    if first_doc is None:
        raise RuntimeError(
            f"No pending Protocol documents left in {FILE_HISTORY_TABLE} to derive a temp path from"
        )
    reference_source_path = first_doc['extracted_path']
base_volume_path = '/'.join(reference_source_path.split('/')[:-2])
TEMP_SECTIONS_PATH = f"{base_volume_path}/temp_sections/{databricks_run_id}"

# Create temp directory
os.makedirs(TEMP_SECTIONS_PATH, exist_ok=True)
print(f"üìÅ Temp sections path: {TEMP_SECTIONS_PATH}")

# Schema for Phase 1 output (extracted files)
extraction_schema = StructType([
    StructField("section_id", StringType(), False),
    StructField("document_id", StringType(), False),
    StructField("section_title", StringType(), True),
    StructField("section_level", IntegerType(), True),
    StructField("page_start", IntegerType(), True),
    StructField("page_end", IntegerType(), True),
    StructField("temp_file_path", StringType(), True),
    StructField("extraction_status", StringType(), True),
    StructField("error_message", StringType(), True)
])

# Plain module-level value read by extract_sections_to_files; it reaches the
# workers because Spark serializes the function together with the globals it
# references. It is NOT a Spark broadcast variable.
temp_path_broadcast = TEMP_SECTIONS_PATH

def extract_sections_to_files(pdf_iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """
    PHASE 1: write each TOC section's page range to its own temp PDF.

    Runs inside ``mapInPandas`` on the workers using plain Python file I/O
    (no Spark SQL). Each section is written to
    ``{temp_path_broadcast}/{section_id}.pdf``.

    Args:
        pdf_iterator: Batches with the columns of ``sections_df``
            (section_id, document_id, source_path, section_title,
            section_level, page_start, page_end).

    Yields:
        One DataFrame per input batch with the columns of ``extraction_schema``.
        extraction_status is 'EXTRACTED' (temp_file_path set) or 'FAILED'
        (error_message set, truncated to 500 characters). A section whose page
        range contains no pages of the document is reported as FAILED rather
        than written out as an empty PDF.
    """
    import os
    from contextlib import ExitStack
    from PyPDF2 import PdfReader, PdfWriter

    output_columns = [
        'section_id', 'document_id', 'section_title', 'section_level',
        'page_start', 'page_end', 'temp_file_path', 'extraction_status', 'error_message'
    ]

    # One PdfReader per source document for the whole partition: a protocol has
    # many sections, and re-parsing the full PDF for each of them dominates the
    # cost. The source files must stay open until the writers have copied the
    # pages, so the ExitStack owns the handles and closes them all when this
    # generator finishes (or is closed early).
    with ExitStack() as open_sources:
        readers = {}

        def get_reader(source_path):
            reader = readers.get(source_path)
            if reader is None:
                reader = PdfReader(open_sources.enter_context(open(source_path, 'rb')))
                readers[source_path] = reader
            return reader

        for pdf in pdf_iterator:
            results = []

            for row in pdf.to_dict('records'):
                section_id = row['section_id']
                page_start = row['page_start']
                page_end = row['page_end']

                # Use Volume path for temp files (accessible by Spark SQL in Phase 2)
                temp_file_path = f"{temp_path_broadcast}/{section_id}.pdf"

                section_info = {
                    'section_id': section_id,
                    'document_id': row['document_id'],
                    'section_title': row['section_title'],
                    'section_level': row['section_level'],
                    'page_start': page_start,
                    'page_end': page_end,
                }

                try:
                    reader = get_reader(row['source_path'])
                    total_pages = len(reader.pages)

                    # page_end may be a "to the end" sentinel, so clamp it to the
                    # last page of the document.
                    first_page = max(int(page_start), 0)
                    last_page = min(int(page_end), total_pages - 1)
                    if first_page > last_page:
                        raise ValueError(
                            f"Empty page range {page_start}-{page_end} "
                            f"(document has {total_pages} pages)"
                        )

                    writer = PdfWriter()
                    for page_num in range(first_page, last_page + 1):
                        writer.add_page(reader.pages[page_num])

                    # Ensure directory exists
                    os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)

                    with open(temp_file_path, 'wb') as outfile:
                        writer.write(outfile)

                    results.append({
                        **section_info,
                        'temp_file_path': temp_file_path,
                        'extraction_status': 'EXTRACTED',
                        'error_message': None
                    })

                except Exception as e:
                    results.append({
                        **section_info,
                        'temp_file_path': None,
                        'extraction_status': 'FAILED',
                        'error_message': str(e)[:500]
                    })

            # `columns` fixes the output layout, including for an empty batch.
            yield pd.DataFrame(results, columns=output_columns)

print("‚úÖ Phase 1 extraction function defined")


In [None]:
# Cell 7: Execute PHASE 1 - Extract sections to temp files (parallel)

print("\n" + "="*60)
print("PHASE 1: Extracting PDF sections in parallel...")
print("="*60)

import time
from collections import Counter

phase1_start = time.time()

# Apply Phase 1 extraction function via mapInPandas
extracted_df = sections_df.mapInPandas(
    extract_sections_to_files,
    schema=extraction_schema
)

# mapInPandas is lazy: every action on extracted_df (or on a temp view over it)
# re-runs the extraction and rewrites every temp PDF. Run it exactly once by
# collecting the small per-section metadata rows, then register a DataFrame
# built from those local rows as the view used by Phase 2.
extraction_rows = extracted_df.collect()
extracted_df = spark.createDataFrame(extraction_rows, schema=extraction_schema)
extracted_df.createOrReplaceTempView("extracted_sections_temp")

status_counts = Counter(row['extraction_status'] for row in extraction_rows)

phase1_elapsed = time.time() - phase1_start

print(f"\n‚úÖ Phase 1 complete in {phase1_elapsed:.1f} seconds")
for extraction_status, count in status_counts.items():
    print(f"  {extraction_status}: {count} sections")

# Surface a few failure reasons instead of only a count.
failed_extractions = [row for row in extraction_rows if row['extraction_status'] != 'EXTRACTED']
for row in failed_extractions[:5]:
    print(f"    - {row['section_title']}: {row['error_message']}")

# Get successfully extracted sections
extracted_count = status_counts.get('EXTRACTED', 0)
print(f"\nüìä {extracted_count} sections ready for AI parsing")


In [None]:
# Cell 8: Execute PHASE 2 - Call ai_parse_document via DataFrame API (parallel)

print("\n" + "="*60)
print("PHASE 2: Parsing sections with ai_parse_document...")
print("="*60)

phase2_start = time.time()

# Get the list of extracted file paths
extracted_paths_df = spark.sql("""
    SELECT section_id, document_id, section_title, section_level,
           page_start, page_end, temp_file_path
    FROM extracted_sections_temp
    WHERE extraction_status = 'EXTRACTED'
""")

# Collect once: the metadata is small, and rebuilding a DataFrame from these
# local rows keeps the join below from re-evaluating the view's lineage.
extracted_paths = extracted_paths_df.collect()
print(f"üìÑ Processing {len(extracted_paths)} extracted sections...")

if extracted_paths:
    # Metadata DataFrame for joining. Reuse the view's schema rather than
    # inferring types from the rows, so INT columns stay INT and a column that
    # happens to be NULL everywhere cannot break schema inference.
    paths_df = spark.createDataFrame(extracted_paths, schema=extracted_paths_df.schema)

    # Read all binary files using DataFrame API (Spark parallelizes this)
    file_paths = [row['temp_file_path'] for row in extracted_paths]
    binary_df = spark.read.format("binaryFile").load(file_paths)

    # Match on the bare file name ({section_id}.pdf): the `path` column from
    # binaryFile may carry a scheme prefix (e.g. file:/, dbfs:/) that
    # temp_file_path does not have.
    binary_df = binary_df.withColumn(
        "filename",
        F.element_at(F.split("path", "/"), -1)
    )
    paths_df = paths_df.withColumn(
        "filename",
        F.element_at(F.split("temp_file_path", "/"), -1)
    )

    # Joining on the column name keeps a single `filename` column; an
    # expression join would keep both copies and make `filename` ambiguous.
    binary_with_meta = binary_df.join(paths_df, on="filename", how="inner")

    # Apply ai_parse_document using expr (runs in parallel across cluster)
    parsed_df = binary_with_meta.withColumn(
        "parsed_result",
        expr("ai_parse_document(content, map('version', '2.0', 'descriptionElementTypes', '*'))")
    )

    # Extract elements as string for processing
    parsed_elements_df = parsed_df.selectExpr(
        "section_id",
        "document_id",
        "section_title",
        "section_level",
        "page_start",
        "page_end",
        "temp_file_path",
        "CAST(parsed_result:document:elements AS STRING) as elements_json",
        "CAST(parsed_result:error_status AS STRING) as error_status"
    )

    # Collect results NOW to avoid re-computation in Phase 3
    # This is the single point where ai_parse_document is executed
    parsed_rows = parsed_elements_df.collect()

    phase2_elapsed = time.time() - phase2_start
    parsed_count = len(parsed_rows)

    # The inner join silently drops sections whose temp file was not matched.
    unmatched_sections = len(extracted_paths) - len({row['section_id'] for row in parsed_rows})
    if unmatched_sections:
        print(f"WARNING: {unmatched_sections} extracted sections were not matched to a temp file and were not parsed")

    print(f"\n‚úÖ Phase 2 complete in {phase2_elapsed:.1f} seconds")
    print(f"üìä Parsed {parsed_count} sections")
else:
    print("‚ö†Ô∏è No sections to parse")
    phase2_elapsed = 0
    parsed_rows = []  # Empty list for Phase 3


In [None]:
# Cell 9: PHASE 3 - Process parsed results and write to sandbox table

print("\n" + "="*60)
print("PHASE 3: Processing parsed content and writing to sandbox...")
print("="*60)

phase3_start = time.time()

# Use parsed_rows collected in Phase 2 (ai_parse_document is not re-run here).
# A NULL error_status, or the JSON literal 'null', marks a successful parse.
parsed_rows_success = []
parsed_rows_failed = []
for row in parsed_rows:
    if row['error_status'] is None or row['error_status'] == 'null':
        parsed_rows_success.append(row)
    else:
        parsed_rows_failed.append(row)

# Report skipped sections instead of dropping them silently.
if parsed_rows_failed:
    print(f"WARNING: {len(parsed_rows_failed)} sections returned an ai_parse_document error and will be skipped")
    for row in parsed_rows_failed[:5]:
        print(f"    - {row['section_title']}: {str(row['error_status'])[:200]}")

print(f"üìÑ Processing {len(parsed_rows_success)} successfully parsed sections...")

# Process elements and create content rows
import html

# ai_parse_document element types treated as running text; consecutive text
# elements are merged into a single text content item.
TEXT_ELEMENT_TYPES = ('paragraph', 'text', 'title', 'section_header', 'line', 'caption')

# `\b` after the tag name stops e.g. `<th` from also matching `<thead>`.
_HTML_ROW_RE = re.compile(r'<tr\b[^>]*>(.*?)</tr>', re.IGNORECASE | re.DOTALL)
_HTML_CELL_RE = re.compile(r'<(t[hd])\b[^>]*>(.*?)</\1>', re.IGNORECASE | re.DOTALL)
_HTML_TAG_RE = re.compile(r'<[^>]+>')


def _html_cell_text(fragment: str) -> str:
    """Return a table cell's plain text: tags removed, entities decoded, trimmed."""
    return html.unescape(_HTML_TAG_RE.sub('', fragment)).strip()


def _parse_html_table(table_html: str) -> dict:
    """Convert an HTML table into the stored table payload.

    Rows made only of <th> cells are header rows; their cell texts are
    concatenated into ``headers``. Every row with at least one <td> is a data
    row and keeps all of its cells (including a leading <th> row label) in
    document order, so row-labelled rows are not lost.

    Returns:
        dict with keys 'html', 'headers', 'rows', 'row_count', 'column_count'.
        column_count is the header count, or the widest data row when the
        table has no header row.
    """
    headers = []
    rows = []
    for row_html in _HTML_ROW_RE.findall(table_html):
        cells = _HTML_CELL_RE.findall(row_html)
        if not cells:
            continue
        texts = [_html_cell_text(inner) for _tag, inner in cells]
        if all(tag.lower() == 'th' for tag, _inner in cells):
            headers.extend(texts)
        else:
            rows.append(texts)
    column_count = len(headers) if headers else max((len(r) for r in rows), default=0)
    return {
        'html': table_html,
        'headers': headers,
        'rows': rows,
        'row_count': len(rows),
        'column_count': column_count
    }


def _elements_to_ordered_content(elements: list) -> list:
    """Turn ai_parse_document elements into ordered {'type', 'data'} content items.

    Table elements whose content is not an HTML table are skipped, as are
    element types that are neither 'table' nor in TEXT_ELEMENT_TYPES.
    """
    ordered_content = []
    for elem in elements:
        if not isinstance(elem, dict):
            continue
        elem_type = str(elem.get('type') or '').lower()
        content = elem.get('content', '')

        if elem_type == 'table':
            if isinstance(content, str) and '<table' in content.lower():
                ordered_content.append({'type': 'table', 'data': _parse_html_table(content)})
        elif elem_type in TEXT_ELEMENT_TYPES:
            text_val = elem.get('text', '') or content
            if isinstance(text_val, str) and text_val.strip():
                # Merge consecutive text elements
                if ordered_content and ordered_content[-1]['type'] == 'text':
                    ordered_content[-1]['data']['content'] += '\n' + text_val
                else:
                    ordered_content.append({'type': 'text', 'data': {'content': text_val}})
    return ordered_content


all_content_rows_data = []
# One timestamp for the whole extraction so rows written together share created_ts.
extraction_ts = datetime.now()

for row in parsed_rows_success:
    section_id = row['section_id']
    elements_json = row['elements_json']

    if not elements_json:
        continue

    try:
        elements = json.loads(elements_json)
        # A JSON null (or any non-list payload) carries no elements.
        if not isinstance(elements, list):
            continue
        ordered_content = _elements_to_ordered_content(elements)

        # Create content rows; study_id and protocol_id come from the job
        # widgets and are stamped on every row.
        # NOTE(review): the pending-docs query does not filter by protocol_id,
        # so if several Protocol documents are pending they all receive the
        # same IDs - confirm this is intended.
        for content_order, content_item in enumerate(ordered_content, start=1):
            all_content_rows_data.append({
                "section_id": section_id,
                "document_id": row['document_id'],
                "study_id": study_id,
                "protocol_id": protocol_id,
                "section_name": row['section_title'],
                "section_level": row['section_level'],
                "page_start": row['page_start'],
                "page_end": row['page_end'],
                "content_order": content_order,
                "content_type": content_item['type'],
                "content_data": json.dumps(content_item['data']),
                "created_ts": extraction_ts,
                "created_by_principal": created_by_principal,
                "databricks_run_id": databricks_run_id,
                "databricks_job_id": databricks_job_id,
                "databricks_job_name": databricks_job_name
            })
    except Exception as e:
        print(f"  ‚ö†Ô∏è Error processing section {section_id}: {e}")

print(f"‚úÖ Created {len(all_content_rows_data)} content rows")

# Write to protocol document sections table
if all_content_rows_data:
    # Debug: print the first row to verify values and Python types
    print("\nüìã DEBUG - First row sample:")
    first_row = all_content_rows_data[0]
    for key, value in first_row.items():
        val_type = type(value).__name__
        # `is not None` so falsy values such as 0 are previewed as-is.
        val_preview = str(value)[:50] if value is not None else "None"
        print(f"  {key}: {val_preview}... ({val_type})")

    # Check for any NULL in critical fields
    null_issues = [
        f"Row {i}: {field} is None"
        for i, row in enumerate(all_content_rows_data)
        for field in ('section_id', 'document_id')
        if row.get(field) is None
    ]

    if null_issues:
        print(f"\n‚ö†Ô∏è Found {len(null_issues)} NULL issues:")
        for issue in null_issues[:5]:  # Show first 5
            print(f"  {issue}")
    else:
        print("\n‚úÖ No NULL issues in critical fields")

    # Build the DataFrame against an explicit, all-nullable schema mirroring
    # SECTIONS_TABLE instead of inferring one:
    #  * a column that is None in every row (study_id / protocol_id when the
    #    widgets are left empty) cannot be inferred, and createDataFrame fails;
    #  * Python ints would be inferred as LONG while the table uses INT.
    # Declared locally so it does not depend on any module-level schema variable.
    content_schema = StructType([
        StructField("section_id", StringType(), True),
        StructField("document_id", StringType(), True),
        StructField("study_id", StringType(), True),
        StructField("protocol_id", StringType(), True),
        StructField("section_name", StringType(), True),
        StructField("section_level", IntegerType(), True),
        StructField("page_start", IntegerType(), True),
        StructField("page_end", IntegerType(), True),
        StructField("content_order", IntegerType(), True),
        StructField("content_type", StringType(), True),
        StructField("content_data", StringType(), True),
        StructField("created_ts", TimestampType(), True),
        StructField("created_by_principal", StringType(), True),
        StructField("databricks_run_id", StringType(), True),
        StructField("databricks_job_id", StringType(), True),
        StructField("databricks_job_name", StringType(), True)
    ])
    # Tuples in schema order make the column mapping explicit.
    content_df = spark.createDataFrame(
        [tuple(row[field.name] for field in content_schema.fields) for row in all_content_rows_data],
        schema=content_schema
    )

    content_df.write \
        .mode("append") \
        .option("mergeSchema", "true") \
        .saveAsTable(SECTIONS_TABLE)

    all_content_rows = len(all_content_rows_data)
    print(f"‚úÖ Wrote {all_content_rows} rows to {SECTIONS_TABLE}")
else:
    all_content_rows = 0
    print("‚ö†Ô∏è No content to write")

phase3_elapsed = time.time() - phase3_start
print(f"\n‚úÖ Phase 3 complete in {phase3_elapsed:.1f} seconds")


In [None]:
# Cell 10: Cleanup temp files

print("\n" + "="*60)
print("Cleaning up temp files...")
print("="*60)

import shutil
import glob

# Best-effort cleanup: any error is reported and the notebook carries on.
try:
    # 1) This run's folder of temp section PDFs.
    if not os.path.exists(TEMP_SECTIONS_PATH):
        print("‚ÑπÔ∏è No temp directory to clean up")
    else:
        file_count = len(glob.glob(f"{TEMP_SECTIONS_PATH}/*.pdf"))
        print(f"üìÅ Found {file_count} temp PDF files to delete")

        shutil.rmtree(TEMP_SECTIONS_PATH)
        print(f"‚úÖ Removed temp directory: {TEMP_SECTIONS_PATH}")

    # 2) The shared temp_sections parent, only once no other run uses it.
    parent_temp_path = os.path.dirname(TEMP_SECTIONS_PATH)
    if os.path.exists(parent_temp_path):
        remaining_dirs = os.listdir(parent_temp_path)
        if remaining_dirs:
            print(f"‚ÑπÔ∏è Parent directory has {len(remaining_dirs)} other run directories")
        else:
            os.rmdir(parent_temp_path)
            print(f"‚úÖ Removed empty parent directory: {parent_temp_path}")

except Exception as e:
    print(f"‚ö†Ô∏è Error cleaning up temp files: {e}")

# Print total processing time across the three phases
total_elapsed = sum((phase1_elapsed, phase2_elapsed, phase3_elapsed))
print(f"\nüìä Total processing time: {total_elapsed:.1f} seconds")


In [None]:
# Cell 11: Update document status in md_file_history

# processed_doc_ids and failed_doc_ids are populated in Cell 5 (STEP 1).


def _update_file_history_status(doc_ids, new_status):
    """Set ``status`` and the job lineage columns on the given md_file_history rows.

    All values are bound as named parameter markers rather than spliced into
    the SQL text, so quotes in document IDs or principal names cannot break
    (or alter) the statement. Lineage values are bound as strings.

    Args:
        doc_ids: document_id values to update (duplicates are ignored).
        new_status: value written to the ``status`` column.
    """
    unique_ids = list(dict.fromkeys(doc_ids))
    args = {f"doc_id_{i}": str(doc_id) for i, doc_id in enumerate(unique_ids)}
    id_markers = ", ".join(f":{marker}" for marker in args)
    args.update({
        "new_status": new_status,
        "principal": str(created_by_principal),
        "job_id": str(databricks_job_id),
        "job_name": str(databricks_job_name),
        "run_id": str(databricks_run_id),
    })
    spark.sql(
        f"""
        UPDATE {FILE_HISTORY_TABLE}
        SET
            status = :new_status,
            status_timestamp = current_timestamp(),
            last_updated_ts = current_timestamp(),
            last_updated_by_principal = :principal,
            databricks_job_id = :job_id,
            databricks_job_name = :job_name,
            databricks_run_id = :run_id
        WHERE document_id IN ({id_markers})
        """,
        args=args,
    )


if processed_doc_ids:
    _update_file_history_status(processed_doc_ids, 'PROTOCOL_COMPLETED')
    print(f"‚úÖ Updated {len(processed_doc_ids)} documents to PROTOCOL_COMPLETED")

if failed_doc_ids:
    _update_file_history_status(failed_doc_ids, 'PROTOCOL_FAILED')
    print(f"‚ö†Ô∏è Updated {len(failed_doc_ids)} documents to PROTOCOL_FAILED")


In [None]:
# Cell 12: Update protocol draft status and exit with summary

# Distinct sections that produced at least one content row. all_content_rows
# counts individual text/table items, which is not a section count.
# NOTE(review): confirm md_protocol_draft.section_count consumers expect sections.
sections_with_content = len({row['section_id'] for row in all_content_rows_data})

# Update md_protocol_draft status if protocol_id is provided
if protocol_id:
    protocol_draft_table = f"{catalog}.{silver_schema}.md_protocol_draft"
    try:
        if len(failed_doc_ids) > 0 and len(processed_doc_ids) == 0:
            new_status = "FAILED"
        elif all_content_rows > 0:
            new_status = "SECTIONS_EXTRACTED"
        else:
            new_status = "PARSING"

        # Values are bound as parameter markers so quotes in IDs/principal
        # names cannot break the statement.
        spark.sql(
            f"""
            UPDATE {protocol_draft_table}
            SET status = :new_status,
                section_count = :section_count,
                last_updated_by_principal = :principal,
                last_updated_ts = current_timestamp()
            WHERE protocol_id = :protocol_id
            """,
            args={
                "new_status": new_status,
                "section_count": sections_with_content,
                "principal": str(created_by_principal),
                "protocol_id": str(protocol_id),
            },
        )
        print(f"‚úÖ Updated protocol {protocol_id} status to {new_status}")
    except Exception as e:
        print(f"‚ö†Ô∏è Could not update protocol draft status: {e}")

# Determine overall status based on results
if len(failed_doc_ids) > 0 and len(processed_doc_ids) == 0:
    overall_status = "failed"
elif len(failed_doc_ids) > 0:
    overall_status = "partial_success"
else:
    overall_status = "success"

summary = {
    "status": overall_status,
    "sandbox_mode": sandbox_mode,
    "documents_processed": len(processed_doc_ids),
    "documents_failed": len(failed_doc_ids),
    "content_rows_extracted": all_content_rows,
    "sections_with_content": sections_with_content,
    "sections_table": SECTIONS_TABLE,
    "run_id": databricks_run_id,
    "processing_mode": "two_phase_parallel",
    "total_time_seconds": round(total_elapsed, 1)
}

print("\n" + "="*60)
print("Protocol Processor Summary (Two-Phase Parallel)")
print("="*60)
for key, value in summary.items():
    print(f"  {key}: {value}")
print("="*60)

# Fail the job if all documents failed
if overall_status == "failed":
    raise Exception(f"Protocol processing failed for all {len(failed_doc_ids)} documents.")

dbutils.notebook.exit(json.dumps(summary))
