# Challenge: Failure Recovery and Checkpoint Management

## Task Description
In this challenge, we need to:
1. Set up checkpointing for Spark Structured Streaming
2. Implement recovery mechanisms for partial failures
3. Properly manage Kafka offsets
4. Ensure exactly-once delivery semantics

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import time
import os

# Create Spark session
spark = SparkSession.builder \
    .appName("Fault Tolerant Streaming") \
    .master("local[*]") \
    .config("spark.sql.shuffle.partitions", 8) \
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.4.1") \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")

## Setup Checkpointing

Checkpointing is essential for fault tolerance in streaming applications. It lets Spark recover after a failure by persisting the query's progress (the source offsets of each micro-batch) and any operator state to durable storage.

In [None]:
# Define checkpoint directory
checkpoint_dir = "/tmp/spark-checkpoints/vpc-security"

# TODO: Create checkpoint directory if it doesn't exist
# Note: In a production environment, this would typically be on HDFS, S3, or another distributed filesystem
# For this notebook, we'll use a local directory

# `os` is already imported in the first cell; exist_ok=True makes this safe to re-run
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"Checkpoint directory ready: {checkpoint_dir}")

## Define Schema for Streaming Data

We'll use the same schema as in previous challenges for consistency.

In [None]:
# Define schema for VPN connection events
schema = StructType([
    StructField("user_id", StringType(), True),
    StructField("timestamp", StringType(), True),
    StructField("country", StringType(), True),
    StructField("ip_address", StringType(), True),
    StructField("status", StringType(), True),
    StructField("duration_seconds", IntegerType(), True)
])

## Create a Fault-Tolerant Streaming Query

Let's set up a streaming query with proper checkpointing and error handling.

In [None]:
# TODO: Create a function to build a streaming query with fault tolerance
def create_fault_tolerant_stream(checkpoint_location):
    """Create a streaming query with checkpointing for fault tolerance"""
    
    # Read from Kafka. startingOffsets only applies on the first run;
    # once a checkpoint exists, Spark resumes from the offsets stored there
    stream_df = spark \
        .readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "kafka:9092") \
        .option("subscribe", "vpn_connection_events") \
        .option("startingOffsets", "earliest") \
        .option("failOnDataLoss", "false") \
        .load()
    
    # Parse JSON data
    parsed_df = stream_df \
        .selectExpr("CAST(value AS STRING) as json") \
        .select(from_json(col("json"), schema).alias("data")) \
        .select("data.*")
    
    # Add event timestamp for window operations
    timestamped_df = parsed_df \
        .withColumn("event_time", to_timestamp(col("timestamp"))) \
        .withWatermark("event_time", "10 minutes")
    
    # Perform aggregation
    aggregated_df = timestamped_df \
        .groupBy(
            window(col("event_time"), "5 minutes"),
            col("country")
        ) \
        .agg(
            count("*").alias("connection_count"),
            sum(when(col("status") == "success", 1).otherwise(0)).alias("successful_connections"),
            sum(when(col("status") == "failed", 1).otherwise(0)).alias("failed_connections")
        )
    
    # Start the streaming query with checkpoint location
    query = aggregated_df \
        .writeStream \
        .outputMode("update") \
        .format("console") \
        .option("truncate", "false") \
        .option("checkpointLocation", checkpoint_location) \
        .start()
    
    return query

## Exactly-Once Semantics with Idempotent Sink

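Checkpointing alone gives at-least-once delivery to a custom sink: after a failure, Spark replays the unfinished micro-batch with the same `batch_id`. To get exactly-once results end to end, we also need a replayable source (Kafka) and a sink whose writes are idempotent or transactional.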
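
As a minimal sketch of what an idempotent write could look like, the snippet below records each `batch_id` in the same transaction as the data, so a replayed batch changes nothing. It uses an in-memory SQLite database as a stand-in for PostgreSQL; the table names `applied_batches` and `country_metrics`, the helper `write_batch_idempotently`, and the row shape are all hypothetical and not part of this challenge's setup.

```python
import sqlite3


def write_batch_idempotently(conn, batch_id, rows):
    """Apply one micro-batch at most once, keyed by batch_id.

    rows: iterable of (country, connection_count) tuples (hypothetical shape).
    Returns True if the batch was applied, False if it was a replay.
    """
    try:
        with conn:  # one transaction: commit on success, roll back on error
            # The primary key on batch_id rejects a batch that was already applied.
            conn.execute(
                "INSERT INTO applied_batches (batch_id) VALUES (?)", (batch_id,)
            )
            conn.executemany(
                "INSERT INTO country_metrics (batch_id, country, connection_count) "
                "VALUES (?, ?, ?)",
                [(batch_id, country, n) for country, n in rows],
            )
        return True
    except sqlite3.IntegrityError:
        # Duplicate batch_id: the transaction was rolled back, nothing changed.
        return False


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE applied_batches (batch_id INTEGER PRIMARY KEY)")
conn.execute(
    "CREATE TABLE country_metrics (batch_id INTEGER, country TEXT, connection_count INTEGER)"
)
conn.commit()

first = write_batch_idempotently(conn, 7, [("DE", 3), ("US", 5)])
replay = write_batch_idempotently(conn, 7, [("DE", 3), ("US", 5)])  # simulated retry
total = conn.execute("SELECT SUM(connection_count) FROM country_metrics").fetchone()[0]
print(first, replay, total)
```

Inside `foreachBatch`, the same idea would apply to the `batch_id` argument that Spark passes to the batch function: because a replayed batch carries the same ID, the second write becomes a no-op.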

In [None]:
# TODO: Implement exactly-once semantics with foreachBatch
def create_exactly_once_stream(checkpoint_location):
    """Create a streaming query with exactly-once semantics"""
    
    # Read from Kafka
    stream_df = spark \
        .readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "kafka:9092") \
        .option("subscribe", "vpn_connection_events") \
        .option("startingOffsets", "earliest") \
        .option("failOnDataLoss", "false") \
        .load()
    
    # Parse JSON data
    parsed_df = stream_df \
        .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING) as json", "topic", "partition", "offset") \
        .select(
            col("key"),
            from_json(col("json"), schema).alias("data"),
            col("topic"),
            col("partition"),
            col("offset")
        ) \
        .select(
            col("topic"),
            col("partition"),
            col("offset"),
            col("data.*")
        )
    
    # Add timestamp
    processed_df = parsed_df \
        .withColumn("event_time", to_timestamp(col("timestamp"))) \
        .withColumn("processing_time", current_timestamp())
    
    # Define a foreachBatch function with idempotent write logic
    def process_batch(batch_df, batch_id):
        if batch_df.isEmpty():
            print(f"Batch {batch_id} is empty, skipping")
            return
        
        # Generate batch metrics
        metrics_df = batch_df \
            .groupBy("country") \
            .agg(
                count("*").alias("connection_count"),
                max("offset").alias("max_offset"),
                min("offset").alias("min_offset")
            )
        
        try:
            # In a real application, this would write to a transactional sink
            # For PostgreSQL, we would use a transaction
            print(f"\nProcessing batch {batch_id} with {batch_df.count()} records")
            
            # Example: Idempotent write to PostgreSQL using a unique constraint
            # This is a simulation:
            print("Metrics by country:")
            metrics_df.show()
            
            # Simulate writing to PostgreSQL with transaction
            print(f"Batch {batch_id} processed successfully")
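            # Kafka offsets are per partition, so this is only the highest offset seen in the batch.
            # Spark itself marks the batch as committed after this function returns.
            print(f"Highest offset in batch: {metrics_df.agg({'max_offset': 'max'}).collect()[0][0]}")
            
        except Exception as e:
            # Re-raising fails the query; on restart, Spark replays this batch from the checkpoint
            print(f"Error processing batch {batch_id}: {str(e)}")
            raise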
    
    # Start streaming query with exactly-once semantics
    query = processed_df \
        .writeStream \
        .foreachBatch(process_batch) \
        .option("checkpointLocation", checkpoint_location) \
        .start()
    
    return query

## Recovery Simulation

Let's simulate a failure and recovery scenario to test our fault tolerance.

In [None]:
# TODO: Simulate a failure and recovery
def test_failure_recovery():
    checkpoint_subdir = f"{checkpoint_dir}/recovery-test-{int(time.time())}"
    print(f"Starting streaming query with checkpoint at {checkpoint_subdir}")
    
    # Start a query
    query = create_exactly_once_stream(checkpoint_subdir)
    
    # Let it run for a few seconds
    print("Running query for 10 seconds...")
    time.sleep(10)
    
    # Simulate a failure by stopping the query
    # (stop() is a graceful shutdown; a real crash can interrupt a batch mid-write)
    print("\nSimulating failure by stopping the query...")
    query.stop()
    
    # Wait a moment
    print("Waiting 5 seconds before recovery...")
    time.sleep(5)
    
    # Restart the query with the same checkpoint dir
    print("\nRestarting query with the same checkpoint location...")
    recovered_query = create_exactly_once_stream(checkpoint_subdir)
    
    # Let it run again to show recovery
    print("Running recovered query for 20 seconds...")
    time.sleep(20)
    
    # Stop the recovered query
    print("\nStopping recovered query...")
    recovered_query.stop()
    
    print("\nFailure recovery test complete")
    print(f"Check {checkpoint_subdir} to see the checkpoint files created")

# Run the test
# Note: Comment this out during development and only run it when the Kafka producer is running
# test_failure_recovery()
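
One way to check the outcome of a recovery test is to compare the offset ranges that each batch reported. The sketch below is plain Python and assumes we have collected, per batch and per Kafka partition, the minimum and maximum offsets written (the `process_batch` function above computes them per country, so grouping by `partition` instead would be a small change). The function name `check_offset_ranges` and the sample data are hypothetical.

```python
from collections import defaultdict


def check_offset_ranges(batches):
    """Find missing and repeated offsets across recorded batches.

    batches: list of (partition, min_offset, max_offset) with inclusive bounds.
    Returns (gaps, overlaps), each a list of (partition, first, last).
    """
    by_partition = defaultdict(list)
    for partition, lo, hi in batches:
        by_partition[partition].append((lo, hi))

    gaps, overlaps = [], []
    for partition, ranges in sorted(by_partition.items()):
        ranges.sort()
        covered_hi = ranges[0][1]  # highest offset covered so far
        for lo, hi in ranges[1:]:
            if lo > covered_hi + 1:
                gaps.append((partition, covered_hi + 1, lo - 1))
            elif lo <= covered_hi:
                overlaps.append((partition, lo, min(hi, covered_hi)))
            covered_hi = max(covered_hi, hi)
    return gaps, overlaps


recorded = [
    (0, 0, 99), (0, 100, 149), (0, 160, 200),  # offsets 150-159 never seen
    (1, 0, 49), (1, 40, 80),                   # offsets 40-49 seen twice
]
gaps, overlaps = check_offset_ranges(recorded)
print("gaps:", gaps)
print("overlaps:", overlaps)
```

A gap is only a hint, not proof of data loss: offsets in a Kafka partition are not always consecutive, for example on compacted topics or when transactional producers are used.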

## Managing Kafka Offsets for Recovery

Understanding how Kafka offsets are managed is crucial for proper recovery.

In [None]:
# TODO: Explore Kafka offset management
def explain_kafka_offset_management():
    """Explain how Kafka offsets are managed for fault tolerance"""
    
    # Create a sample stream for demonstration
    stream_df = spark \
        .readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "kafka:9092") \
        .option("subscribe", "vpn_connection_events") \
        .option("startingOffsets", "earliest") \
        .option("failOnDataLoss", "false") \
        .load() \
        .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "topic", "partition", "offset")
    
    # Explain the query plan to show how offsets are tracked
    print("Query plan showing how Kafka offsets are tracked:")
    stream_df.explain(True)
    
    print("\nKey Kafka offset options for fault tolerance:")
    print("1. startingOffsets: Where to start reading when no checkpoint exists")
    print("   - 'earliest': Start from the beginning of the topic")
    print("   - 'latest': Start from the end of the topic")
    print("   - '{\"topicA\":{\"0\":23,\"1\":-1},\"topicB\":{\"0\":-1}}': Specific offsets per partition")
    
    print("\n2. failOnDataLoss: How to handle data that is no longer available")
    print("   - 'true': Fail the query if data is lost (e.g., due to retention policies)")
    print("   - 'false': Continue the query even if some data is lost")
    
    print("\n3. checkpointLocation (a writeStream option): Where to store offset information")
    print("   - Stores the offset range of every micro-batch, per partition")
    print("   - Used to resume from where processing left off after a failure")
    
    print("\nWhen a failure occurs and the query restarts:")
    print("1. Spark reads the checkpoint to find the last committed batch and its offsets")
    print("2. It re-runs any batch that was planned but not committed, using the same offset range")
    print("3. Results are exactly-once only if the sink is idempotent or transactional")

# Explain Kafka offset management
# explain_kafka_offset_management()
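
The per-partition JSON form of `startingOffsets` is easy to get wrong by hand, so here is a small sketch that builds it from a dictionary. The helper name `build_starting_offsets` is hypothetical; the special values -2 (earliest) and -1 (latest) are the ones Spark's Kafka source accepts inside this JSON.

```python
import json

EARLIEST = -2  # Spark's Kafka source reads -2 as "earliest"
LATEST = -1    # and -1 as "latest"


def build_starting_offsets(offsets_by_topic):
    """Turn {topic: {partition: offset}} into the JSON string for startingOffsets."""
    payload = {
        topic: {str(partition): int(offset) for partition, offset in sorted(parts.items())}
        for topic, parts in offsets_by_topic.items()
    }
    # Compact separators keep the string free of extra whitespace
    return json.dumps(payload, separators=(",", ":"))


starting_offsets = build_starting_offsets(
    {"vpn_connection_events": {0: 23, 1: LATEST}}
)
print(starting_offsets)  # {"vpn_connection_events":{"0":23,"1":-1}}
```

The resulting string could be passed as `.option("startingOffsets", starting_offsets)`. As with `earliest` and `latest`, it only takes effect when the query starts without an existing checkpoint.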

## Examining Checkpoint Files

Let's look at what's stored in the checkpoint directory to better understand the recovery mechanism.

In [None]:
# TODO: Explore checkpoint directory structure
def explore_checkpoint_directory(dir_path):
    """Explore the structure of a checkpoint directory"""
    if not os.path.exists(dir_path):
        print(f"Directory {dir_path} does not exist")
        return
    
    print(f"Contents of {dir_path}:")
    for root, dirs, files in os.walk(dir_path):
        level = root.replace(dir_path, '').count(os.sep)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        sub_indent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{sub_indent}{f}")
    
    # Explain key checkpoint files
    print("\nKey checkpoint files and directories:")
    print("- offsets/: Write-ahead log of the offset range planned for each batch")
    print("- commits/: One marker file per batch that completed successfully")
    print("- metadata: A file (not a directory) holding the query's unique id")
    print("- sources/: Initial offsets recorded for each data source")
    print("- state/: Stateful operator state (aggregations, windowing, etc.); absent for stateless queries")

# Explore a sample checkpoint directory
# Note: This requires running a query first to create checkpoint files
# explore_checkpoint_directory(checkpoint_dir)
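
To see how the `offsets/` and `commits/` directories drive recovery, we can compare the batch IDs found in each. The sketch below assumes the layout observed in Spark 3.x checkpoints, where each file in those directories is named after its batch ID; this layout is an internal detail and may change between versions. The helper names `batch_ids` and `summarize_checkpoint` are hypothetical, and the demo runs against a fake directory rather than a real checkpoint.

```python
import os
import tempfile


def batch_ids(log_dir):
    """Return the sorted batch IDs in a log directory, ignoring hidden/.crc files."""
    if not os.path.isdir(log_dir):
        return []
    return sorted(int(name) for name in os.listdir(log_dir) if name.isdigit())


def summarize_checkpoint(checkpoint_path):
    """Report the last planned batch, the last committed batch, and pending batches."""
    planned = batch_ids(os.path.join(checkpoint_path, "offsets"))
    committed = batch_ids(os.path.join(checkpoint_path, "commits"))
    last_committed = committed[-1] if committed else None
    # A batch newer than the last commit was planned but never finished
    pending = [b for b in planned if last_committed is None or b > last_committed]
    return {
        "last_planned": planned[-1] if planned else None,
        "last_committed": last_committed,
        "pending": pending,
    }


# Demo on a fake layout: batch 2 was planned but the query died before committing it
with tempfile.TemporaryDirectory() as demo_dir:
    for sub, ids in (("offsets", [0, 1, 2]), ("commits", [0, 1])):
        os.makedirs(os.path.join(demo_dir, sub))
        for batch_id in ids:
            open(os.path.join(demo_dir, sub, str(batch_id)), "w").close()
    open(os.path.join(demo_dir, "offsets", ".2.crc"), "w").close()  # checksum sidecar
    summary = summarize_checkpoint(demo_dir)

print(summary)
```

A pending batch is the one Spark would re-run on restart, which is why the sink has to tolerate seeing that batch a second time.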

## Implementing a Robust Streaming Pipeline

Let's combine these fault tolerance techniques into a single pipeline that follows production patterns. The sink is still simulated with console output.

In [None]:
# TODO: Create a robust, fault-tolerant streaming pipeline
def create_robust_pipeline():
    """Create a robust, fault-tolerant streaming pipeline for production use"""
    
    # Define checkpoint location
    robust_checkpoint_dir = f"{checkpoint_dir}/production-pipeline"
    
    # Read from Kafka with fault tolerance options
    raw_stream = spark \
        .readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "kafka:9092") \
        .option("subscribe", "vpn_connection_events") \
        .option("startingOffsets", "earliest") \
        .option("failOnDataLoss", "false") \
        .option("kafkaConsumer.pollTimeoutMs", "5000") \
        .option("fetchOffset.numRetries", "5") \
        .option("fetchOffset.retryIntervalMs", "1000") \
        .load()
    
    # Parse and process data
    parsed_stream = raw_stream \
        .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING) as json", "topic", "partition", "offset") \
        .select(
            col("key"),
            from_json(col("json"), schema).alias("data"),
            col("topic"),
            col("partition"),
            col("offset")
        ) \
        .select(
            col("topic"),
            col("partition"),
            col("offset"),
            col("data.*")
        )
    
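    # Add timestamps and a watermark. The watermark only matters once a stateful
    # operator (e.g. a streaming aggregation or dropDuplicates) is added downstream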
    timestamped_stream = parsed_stream \
        .withColumn("event_time", to_timestamp(col("timestamp"))) \
        .withColumn("processing_time", current_timestamp()) \
        .withWatermark("event_time", "30 minutes")  # Allow for late data
    
    # Implement error handling and retry logic
    # (in-place retries are only safe if the write is idempotent or transactional)
    def process_batch_with_retries(batch_df, batch_id):
        if batch_df.isEmpty():
            return
        
        max_retries = 3
        retry_count = 0
        
        while retry_count < max_retries:
            try:
                # Start a transaction (in a real system)
                print(f"Processing batch {batch_id} (attempt {retry_count + 1})")
                
                # Process data
                country_metrics = batch_df \
                    .groupBy("country") \
                    .agg(
                        count("*").alias("connection_count"),
                        sum(when(col("status") == "success", 1).otherwise(0)).alias("successful"),
                        sum(when(col("status") == "failed", 1).otherwise(0)).alias("failed")
                    )
                
                # In a real system, this would write to a database with a transaction
                print("Metrics by country:")
                country_metrics.show(5, truncate=False)
                
                # Commit transaction (in a real system)
                print(f"Successfully processed batch {batch_id}")
                return
                
            except Exception as e:
                retry_count += 1
                if retry_count >= max_retries:
                    print(f"Failed to process batch {batch_id} after {max_retries} attempts. Error: {str(e)}")
                    # In a production system, would log to error tracking system
                    # and possibly write to a dead-letter queue
                    raise e
                else:
                    print(f"Attempt {retry_count} failed. Retrying... Error: {str(e)}")
                    time.sleep(1)  # Backoff before retry
    
    # Start the query with fault tolerance features
    query = timestamped_stream \
        .writeStream \
        .foreachBatch(process_batch_with_retries) \
        .option("checkpointLocation", robust_checkpoint_dir) \
        .trigger(processingTime="10 seconds") \
        .start()
    
    return query

# Start the robust pipeline
# Note: Only run this when ready to test with real data
# robust_query = create_robust_pipeline()
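
The retry loop above pauses for a fixed second between attempts. A common variation is exponential backoff with jitter, sketched below as a standalone helper. The name `retry_with_backoff` and its parameters are hypothetical; the `sleep` argument exists only so that the demo can record delays instead of actually waiting.

```python
import random
import time


def retry_with_backoff(fn, max_retries=3, base_delay=1.0, max_delay=30.0, sleep=time.sleep):
    """Call fn() up to max_retries times, backing off exponentially between failures."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise  # out of attempts: let the caller (and Spark) see the failure
            # Delay doubles each attempt, is capped, and is scaled by random jitter
            delay = min(max_delay, base_delay * 2 ** attempt)
            sleep(delay * random.uniform(0.5, 1.0))


# Demo: a function that fails twice with a transient error, then succeeds
calls = {"n": 0}


def flaky_write():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"


delays = []
result = retry_with_backoff(flaky_write, sleep=delays.append)
print(result, calls["n"], len(delays))  # ok 3 2
```

Jitter spreads out retries so that several failing tasks do not hit the sink again at the same moment. If the last attempt also fails, the exception propagates, the query stops, and the batch is replayed from the checkpoint on restart.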

## Best Practices for Fault-Tolerant Streaming

Here are some key best practices for building fault-tolerant streaming applications:

1. **Always use checkpointing**: This is essential for recovery after failures

2. **Configure failure handling**:
   - Keep `failOnDataLoss` at its default of `true` when missing data must stop the query; set it to `false` only if you accept skipping offsets that Kafka has already deleted (e.g., due to retention)
   - Implement retry logic with backoff for transient errors

3. **Use reliable storage for checkpoints**:
   - HDFS, S3, or other distributed storage for production
   - Ensure the checkpoint location is accessible by all nodes

4. **Implement idempotent sinks**:
   - Use `foreachBatch` with transaction support
   - Ensure writes are idempotent (e.g., using unique constraints)

5. **Monitor and alert**:
   - Track streaming metrics
   - Set up alerting for failures
   - Keep logs for debugging

6. **Test failure scenarios**:
   - Simulate node failures
   - Test recovery from checkpoint
   - Validate exactly-once semantics

7. **Scale appropriately**:
   - Size your cluster for peak loads plus buffer
   - Use dynamic allocation if available
   - Tune executor memory and CPU

8. **Version your data schemas**:
   - Have a plan for schema evolution
   - Test backward compatibility

9. **Implement dead-letter queues**:
   - Store records that cannot be processed
   - Investigate and reprocess later

10. **Validate end-to-end**:
    - Ensure data consistency
    - Check for duplicate or missing data
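
As an illustration of the dead-letter idea from item 9, the sketch below separates records that parse cleanly from those that do not, keeping the raw payload and the reason for each rejected record. It is plain Python rather than Spark code; the helper name `split_records` and the choice of `REQUIRED_FIELDS` are assumptions made for this example.

```python
import json

# Assumption: a subset of the schema's fields that every usable record must carry
REQUIRED_FIELDS = ("user_id", "timestamp", "country", "status")


def split_records(raw_values):
    """Split raw JSON strings into (valid_records, dead_letters)."""
    valid, dead_letters = [], []
    for raw in raw_values:
        try:
            record = json.loads(raw)
            if not isinstance(record, dict):
                raise ValueError("not a JSON object")
            missing = [f for f in REQUIRED_FIELDS if record.get(f) is None]
            if missing:
                raise ValueError(f"missing fields: {missing}")
            valid.append(record)
        except ValueError as exc:  # json.JSONDecodeError is a subclass of ValueError
            # Keep the original payload so it can be inspected and reprocessed later
            dead_letters.append({"raw": raw, "error": str(exc)})
    return valid, dead_letters


valid, dead = split_records([
    '{"user_id": "u1", "timestamp": "2024-01-01T00:00:00", "country": "DE", "status": "success"}',
    '{"user_id": "u2", "country": "US"}',
    "not json",
])
print(len(valid), len(dead))  # 1 2
```

In a streaming job, the rejected records would typically go to a separate topic or table instead of being dropped, so they can be investigated without blocking the main pipeline.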