In [0]:
%sql
-- Select the catalog used for subsequent unqualified schema/table names.
USE CATALOG `nokia-assginment-catalog`;
-- Manual reset, intentionally left disabled (the next cell offers a widget-controlled drop):
-- drop schema patent_data cascade;

In [0]:
# Try to create a widget to control schema dropping.
# If the widget API is unavailable or fails (e.g. in job mode) we fall back to
# the safe default of NOT dropping anything.
try:
    dbutils.widgets.dropdown("drop_patent_data_schema", "false", ["true", "false"], "Drop schema patent_data cascade")
    drop_patent_data_schema = dbutils.widgets.get("drop_patent_data_schema") == "true"
except Exception:
    # `except Exception` instead of a bare `except` so that KeyboardInterrupt /
    # SystemExit still stop the notebook instead of being silently ignored.
    # Default to not dropping schema in job mode
    drop_patent_data_schema = False

print(f"Drop patent_data schema setting: {drop_patent_data_schema}")

# Execute SQL to drop schema if requested
if drop_patent_data_schema:
    try:
        print("Dropping schema patent_data cascade...")
        spark.sql("DROP SCHEMA IF EXISTS patent_data CASCADE")
        print("Schema patent_data successfully dropped")
    except Exception as e:
        # Best effort: a failed drop is reported but does not abort the notebook.
        print(f"Error dropping schema: {str(e)}")

In [0]:
from pyspark.sql import SparkSession
from delta.tables import DeltaTable
import traceback
import json
import uuid

def initialize_spark():
    """Return a SparkSession (existing or new) requested with Delta Lake settings."""
    # Extension + catalog implementation that enable Delta Lake SQL support.
    delta_settings = {
        "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
        "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    }

    session_builder = SparkSession.builder.appName("Patent Delta Tables Registration")
    for option_key, option_value in delta_settings.items():
        session_builder = session_builder.config(option_key, option_value)

    return session_builder.getOrCreate()

def check_path_exists(path):
    """
    Check if a path exists and is accessible

    Args:
        path: Path to probe with ``dbutils.fs.ls``

    Returns:
        bool: True if listing the path succeeds, False if the listing raises an error
    """
    try:
        dbutils.fs.ls(path)
    except Exception:
        # Any listing failure is reported as "not accessible". Catching Exception
        # (not a bare except) lets KeyboardInterrupt / SystemExit propagate.
        return False
    return True

def is_delta_table(spark, path):
    """
    Check if a path is a valid Delta table

    Args:
        spark: SparkSession
        path: Path to test

    Returns:
        bool: True if ``DeltaTable.forPath`` can open the path, False if it raises an error
    """
    try:
        DeltaTable.forPath(spark, path)
    except Exception:
        # A path that cannot be opened as a Delta table is reported as invalid.
        # Catching Exception (not a bare except) lets KeyboardInterrupt /
        # SystemExit propagate.
        return False
    return True

def check_new_data_processed():
    """
    Check if new data was processed in the previous gold layer step

    Three sources are consulted in order; the first one that yields a value wins:
      1. the most recent row of ``patent_data.processing_metadata``
      2. the Delta status data under the processing_status volume
      3. the "result" entry of the parent notebook context

    If none of them can be read, or an unexpected error occurs, the function
    reports True so that the downstream registration still runs.

    Returns:
        bool: Whether new data was processed (the stored ``new_data_processed``
        value is returned as-is)
    """
    try:
        # Method 1: Check processing metadata table
        spark = initialize_spark()

        database_name = "patent_data"
        try:
            metadata_df = spark.sql(f"SELECT * FROM {database_name}.processing_metadata ORDER BY processing_timestamp DESC LIMIT 1")
            # first() returns None for an empty result, so a separate count() action is not needed
            last_record = metadata_df.first()
            if last_record is not None:
                new_data_processed = last_record.new_data_processed
                print(f"Found processing metadata: new_data_processed={new_data_processed}")
                return new_data_processed
        except Exception as e:
            # The table may simply not exist yet; report why and try the next source
            print(f"Could not query processing_metadata table: {e}")

        # Method 2: Check processing status Delta file
        try:
            processing_status_path = "/Volumes/nokia-assginment-catalog/processing_status/status"
            if check_path_exists(processing_status_path):
                status_df = spark.read.format("delta").load(processing_status_path)
                status_row = status_df.first()
                if status_row is not None:
                    status = status_row.new_data_processed
                    print(f"Found processing status file: new_data_processed={status}")
                    return status
        except Exception as e:
            print(f"Could not read processing status file: {e}")

        # Method 3: Check for previous notebook result
        try:
            # Try to get result from previous notebook
            prev_result_str = dbutils.notebook.entry_point.getDbutils().notebook().getContext().parentContext().get("result")
            if prev_result_str:
                prev_result = json.loads(prev_result_str)
                if "new_data_processed" in prev_result:
                    status = prev_result["new_data_processed"]
                    print(f"Found previous notebook result: new_data_processed={status}")
                    return status
        except Exception as e:
            print(f"Could not get previous notebook result: {e}")

        # Default: If we can't determine, assume there might be new data
        print("Could not determine if new data was processed, assuming yes")
        return True

    except Exception as e:
        print(f"Error checking for new data: {str(e)}")
        print(traceback.format_exc())
        # If there's an error, assume there might be new data
        return True

def register_and_upsert_table(spark, delta_path, table_name, database_name="patent_data"):
    """
    Register a Delta table if it doesn't exist and upsert data

    One of three write strategies is used, depending on the target:
      * the table does not exist yet -> it is created from the source data and
        its Delta table properties are set;
      * the table name contains "_individual" -> its contents are completely
        overwritten with the source data;
      * otherwise -> the source is merged into the table, matching rows on
        ``publication_number`` (update matched rows, insert the rest).

    Args:
        spark: SparkSession
        delta_path: Path to Delta table
        table_name: Name of the table to create/update
        database_name: Database to create the table in

    Returns:
        dict: Operation result with a "status" ("created", "replaced",
        "updated", "warning" or "error") and a human-readable "message".
        Errors are caught and reported through the "error" status.
    """
    try:
        # Debug - what's in the gold delta path?
        # (not guarded: if the source path cannot be listed the whole operation is reported as an error)
        print(f"Contents of {delta_path}:")
        path_files = dbutils.fs.ls(delta_path)
        for file in path_files[:5]:  # Show first 5 items
            print(f"  - {file.name} ({file.size} bytes)")

        # Debug - how many batches are in silver layer (best effort, never fails the operation)
        try:
            silver_path = "/Volumes/nokia-assginment-catalog/silver/patent_data"
            silver_dirs = [d for d in dbutils.fs.ls(silver_path) if d.isDir() and not d.name.startswith('_')]
            print(f"Silver layer has {len(silver_dirs)} batch directories")
        except Exception as e:
            print(f"Could not access silver layer to count batches: {e}")

        # Debug - how many batches have checkpoints (best effort, never fails the operation)
        try:
            checkpoint_location = "/Volumes/nokia-assginment-catalog/checkpoints/checkpoints_data/gold_autoloader/"
            if check_path_exists(checkpoint_location):
                checkpoints = [d for d in dbutils.fs.ls(checkpoint_location) if not d.name.startswith('_')]
                print(f"Found {len(checkpoints)} checkpoint files")
        except Exception as e:
            print(f"Could not check checkpoint files: {e}")

        # Create database if it doesn't exist
        spark.sql(f"CREATE DATABASE IF NOT EXISTS {database_name}")

        # Fully qualified table name
        full_table_name = f"{database_name}.{table_name}"

        # Read the Delta data (source)
        source_df = spark.read.format("delta").load(delta_path)

        # Check if source has data and print more details for debugging
        source_count = source_df.count()
        print(f"Source Delta file at {delta_path} has {source_count} records")

        if source_count == 0:
            return {
                "status": "warning",
                "message": f"Source Delta file at {delta_path} is empty"
            }

        # Check if table already exists
        existing_tables = {t.name for t in spark.catalog.listTables(database_name)}
        table_exists = table_name in existing_tables

        # For individual tables, use overwrite mode instead of merge
        is_individual_table = "_individual" in table_name

        if not table_exists:
            # Create table if it doesn't exist
            print(f"Creating table {full_table_name} from {delta_path}")
            source_df.write.format("delta").mode("overwrite").saveAsTable(full_table_name)

            # Configure table properties for optimization
            spark.sql(f"""
                ALTER TABLE {full_table_name}
                SET TBLPROPERTIES (
                    'delta.autoOptimize.optimizeWrite' = 'true',
                    'delta.autoOptimize.autoCompact' = 'true',
                    'delta.enableChangeDataFeed' = 'true',
                    'delta.logRetentionDuration' = '30 days',
                    'delta.tuneFileSizesForRewrites' = 'true'
                )
            """)

            print(f"Table {full_table_name} created successfully with {source_count} rows")
            return {
                "status": "created",
                "message": f"Table {full_table_name} created with {source_count} rows"
            }

        if is_individual_table:
            # For individual tables, keep a record of how many rows we're about to replace
            target_count = spark.read.table(full_table_name).count()
            print(f"Replacing {target_count} existing rows with {source_count} rows in {full_table_name}")

            # If source has fewer records than target, this could indicate a problem
            if source_count < target_count and target_count > 0:
                print(f"WARNING: Source has fewer records ({source_count}) than existing table ({target_count})")

            # For individual tables, use complete replacement instead of merge
            source_df.write.format("delta").mode("overwrite").saveAsTable(full_table_name)
            print(f"Data in {full_table_name} replaced successfully with {source_count} rows")
            return {
                "status": "replaced",
                "message": f"Table {full_table_name} completely replaced with {source_count} rows (previous: {target_count})"
            }

        # For regular tables, perform upsert keyed on publication_number
        if "publication_number" not in source_df.columns:
            return {
                "status": "warning",
                "message": "Source data doesn't have publication_number column"
            }

        # Get target table and count before merge
        target_table = DeltaTable.forName(spark, full_table_name)
        target_count_before = spark.read.table(full_table_name).count()

        print(f"Performing upsert to {full_table_name} from {delta_path}")
        print(f"Target table has {target_count_before} rows before merge")

        # Sample a few publication numbers from source for debugging
        sample_pubs = source_df.select("publication_number").limit(5).collect()
        print(f"Sample publication numbers from source: {[row.publication_number for row in sample_pubs]}")

        # Perform merge operation
        # NOTE(review): presumably each gold table has at most one source row per
        # publication_number; a merge with several source rows matching one
        # target row is expected to fail - confirm against the gold layer.
        target_table.alias("target").merge(
            source_df.alias("source"),
            "target.publication_number = source.publication_number"
        ).whenMatchedUpdateAll(
        ).whenNotMatchedInsertAll(
        ).execute()

        # Count after merge to see if anything changed
        target_count_after = spark.read.table(full_table_name).count()
        rows_changed = target_count_after - target_count_before

        print(f"Upsert to {full_table_name} completed. Rows before: {target_count_before}, after: {target_count_after}")
        print(f"Net change in rows: {rows_changed}")

        return {
            "status": "updated",
            "message": f"Table {full_table_name} updated with upsert. Net new rows: {rows_changed}"
        }

    except Exception as e:
        error_message = f"Error processing {table_name}: {str(e)}"
        print(error_message)
        print(traceback.format_exc())
        return {
            "status": "error",
            "message": error_message
        }

def register_and_upsert_all_tables(gold_base_path="/Volumes/nokia-assginment-catalog/gold", 
                                   database_name="patent_data"):
    """
    Process all patent gold tables - register if needed and upsert data

    Each gold table is expected at ``<gold_base_path>/<table_name>``. Paths that
    do not exist or are not Delta tables are skipped; all others are handed to
    ``register_and_upsert_table``. The collected results are finally appended as
    one JSON row to ``<database_name>.patent_update_control`` (best effort).

    Args:
        gold_base_path: Base path to gold Delta tables
        database_name: Database to create tables in

    Returns:
        dict: Results for each table, keyed by table name
    """
    spark = initialize_spark()
    results = {}
    control_table = f"{database_name}.patent_update_control"

    # Create control table if it doesn't exist
    try:
        spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {control_table} (
            update_id STRING,
            update_timestamp TIMESTAMP,
            result STRING
        ) USING DELTA
        """)
    except Exception as e:
        print(f"Warning: Could not create control table: {str(e)}")

    # Tables to process; each one lives in a sub-directory of the same name
    table_names = [
        "patents",
        "ipc_classifications",
        "ipc_individual",
        "inventors",
        "inventors_individual",
        "applicants",
        "applicants_individual",
        "us_applicants",
        "us_applicants_individual",
        "claims",
        "claims_individual",
        "complete_patents",
    ]

    # Register and upsert each table
    for table_name in table_names:
        delta_path = f"{gold_base_path}/{table_name}"

        # Check if Delta path exists
        if not check_path_exists(delta_path):
            print(f"Delta path does not exist: {delta_path}")
            results[table_name] = {"status": "skipped", "message": "Delta path not found"}
            continue

        # Check if it's a valid Delta table
        if not is_delta_table(spark, delta_path):
            print(f"Path is not a valid Delta table: {delta_path}")
            results[table_name] = {"status": "skipped", "message": "Not a valid Delta table"}
            continue

        # Register and upsert the table
        results[table_name] = register_and_upsert_table(
            spark,
            delta_path,
            table_name,
            database_name=database_name
        )

    # Log this batch of updates in the control table
    try:
        result_json = json.dumps(results)
        update_id = str(uuid.uuid4())

        # Append through the DataFrame API instead of interpolating the JSON into
        # an INSERT statement: result messages can contain apostrophes and
        # backslash escapes, which would break or alter a quoted SQL literal.
        log_row = spark.createDataFrame(
            [(update_id, result_json)],
            "update_id STRING, result STRING"
        ).selectExpr(
            "update_id",
            "current_timestamp() AS update_timestamp",
            "result"
        )
        log_row.write.format("delta").mode("append").saveAsTable(control_table)

        print(f"Update logged in control table with ID: {update_id}")
    except Exception as e:
        print(f"Warning: Could not log update in control table: {str(e)}")

    return results

def main():
    """Main function that runs as part of the workflow after gold process notebook

    Skips the registration (exiting the notebook with a "skipped" JSON result)
    when the previous step reported no new data and the ``force_update`` widget
    is not "true". Otherwise registers/upserts all gold tables, prints a
    summary and exits the notebook with a JSON result holding the per-status
    table counts.
    """
    from collections import Counter  # only needed for the summary tally below

    # Read the force_update widget; fall back to False where widgets are unavailable
    try:
        dbutils.widgets.dropdown("force_update", "false", ["true", "false"], "Force Table Update")
        force_update = dbutils.widgets.get("force_update") == "true"
    except Exception:
        # `except Exception` (not a bare except) so KeyboardInterrupt / SystemExit still propagate
        force_update = False

    # Check if new data was processed in the previous step
    new_data_processed = check_new_data_processed()

    # Skip processing if no new data and not forcing update
    if not new_data_processed and not force_update:
        print("No new data was processed and force_update=false, skipping table registration")
        result = {
            "status": "skipped",
            "message": "No new data was processed and force_update=false"
        }
        dbutils.notebook.exit(json.dumps(result))
        return

    # Base path to gold Delta tables
    gold_base_path = "/Volumes/nokia-assginment-catalog/gold"
    database_name = "patent_data"

    print("Starting patent data registration and upsert process")
    if new_data_processed:
        print("New data was processed in previous step")
    if force_update:
        print("Force update requested")

    print(f"Source: {gold_base_path}")
    print(f"Target database: {database_name}")

    # Register and upsert all tables
    results = register_and_upsert_all_tables(
        gold_base_path=gold_base_path,
        database_name=database_name
    )

    # Tally the per-table statuses in a single pass (missing statuses count as 0)
    status_counts = Counter(r.get('status') for r in results.values())
    created_count = status_counts['created']
    updated_count = status_counts['updated']
    replaced_count = status_counts['replaced']
    error_count = status_counts['error']
    skipped_count = status_counts['skipped']
    warning_count = status_counts['warning']

    print("\n=== PROCESSING SUMMARY ===")
    print(f"Tables created: {created_count}")
    print(f"Tables updated via upsert: {updated_count}")
    print(f"Tables completely replaced: {replaced_count}")
    print(f"Tables with errors: {error_count}")
    print(f"Tables skipped: {skipped_count}")
    print(f"Tables with warnings: {warning_count}")

    # Print details of any errors
    if error_count > 0:
        print("\n=== ERROR DETAILS ===")
        for table, result in results.items():
            if result.get('status') == 'error':
                print(f"{table}: {result.get('message')}")

    print("\nProcess completed")

    # Return results as JSON
    final_result = {
        "status": "completed",
        "created": created_count,
        "updated": updated_count,
        "replaced": replaced_count,
        "errors": error_count,
        "skipped": skipped_count,
        "warnings": warning_count
    }
    dbutils.notebook.exit(json.dumps(final_result))

# Notebook entry point: run the registration/upsert workflow when this cell executes.
main()