In [1]:
"""
This notebook implements an end-to-end ETL pipeline that:
1. Reads streaming data from Kafka
2. Transforms the data with Spark
3. Writes results to PostgreSQL (for analytics) and MinIO (for archival)
"""

'\nThis notebook implements an end-to-end ETL pipeline that:\n1. Reads streaming data from Kafka\n2. Transforms the data with Spark\n3. Writes results to PostgreSQL (for analytics) and MinIO (for archival)\n'

In [39]:
# Import libraries
# Standard library
import os
import sys
import time
from time import sleep
from typing import Any

# Third-party
import pandas as pd
import psycopg2
from psycopg2 import sql
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import (
    col, from_json, current_timestamp, to_date, hour, dayofweek, when, lit, udf
)
from pyspark.sql.types import (
    StructType, StructField, StringType, TimestampType,
    DoubleType, IntegerType, BooleanType
)

In [40]:
# Create Spark session with MinIO/S3 support.
# MinIO connection settings can be overridden through environment variables so that
# credentials do not have to be committed with the notebook; the defaults are the
# values this notebook has always used for the local `minio` host.
MINIO_ENDPOINT = os.environ.get("MINIO_ENDPOINT", "http://minio:9002")
MINIO_ACCESS_KEY = os.environ.get("MINIO_ACCESS_KEY", "minioadmin")
MINIO_SECRET_KEY = os.environ.get("MINIO_SECRET_KEY", "minioadmin")

# Kafka connector, Hadoop AWS (s3a) and their runtime dependencies.
SPARK_JARS = [
    "/opt/spark/jars/spark-sql-kafka-0-10_2.12-3.5.0.jar",
    "/opt/spark/jars/kafka-clients-3.5.0.jar",
    "/opt/spark/jars/kafka_2.12-3.5.0.jar",
    "/opt/spark/jars/commons-pool2-2.11.1.jar",
    "/opt/spark/jars/lz4-java-1.8.0.jar",
    "/opt/spark/jars/snappy-java-1.1.10.1.jar",
    "/opt/spark/jars/hadoop-aws-3.3.4.jar",
    "/opt/spark/jars/aws-java-sdk-bundle-1.12.262.jar",
]

spark = (SparkSession.builder
    .appName("SmartMeterETL")

    # JAR Configuration
    .config("spark.jars", ",".join(SPARK_JARS))

    # Classpath Configuration
    .config("spark.driver.extraClassPath", "/opt/spark/jars/*")
    .config("spark.executor.extraClassPath", "/opt/spark/jars/*")
    .config("spark.executor.userClassPathFirst", "true")

    # MinIO/S3 Configuration (path-style access over plain HTTP, as MinIO expects)
    .config("spark.hadoop.fs.s3a.access.key", MINIO_ACCESS_KEY)
    .config("spark.hadoop.fs.s3a.secret.key", MINIO_SECRET_KEY)
    .config("spark.hadoop.fs.s3a.endpoint", MINIO_ENDPOINT)
    .config("spark.hadoop.fs.s3a.path.style.access", "true")
    .config("spark.hadoop.fs.s3a.connection.ssl.enabled", "false")
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")

    # Kafka Specific Settings
    .config("spark.sql.streaming.kafka.useDeprecatedOffsetFetching", "false")
    # NOTE(review): the next two keys look like DStream-era (spark-streaming-kafka)
    # settings; the Structured Streaming Kafka source is rate-limited through the
    # `maxOffsetsPerTrigger` read option instead -- confirm these have any effect.
    .config("spark.kafka.consumer.cache.enabled", "false")
    .config("spark.streaming.kafka.maxRatePerPartition", "1000")

    # JVM Options (module opens needed by Spark/Netty on newer JDKs)
    .config("spark.driver.extraJavaOptions",
           "-Dio.netty.tryReflectionSetAccessible=true " +
           "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED " +
           "--add-opens=java.base/java.lang=ALL-UNNAMED " +
           "--add-opens=java.base/java.util=ALL-UNNAMED")

    # Performance Tuning (small local cluster)
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.default.parallelism", "4")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

    .getOrCreate())

In [41]:
# Kafka smoke test: build a batch Kafka DataFrame once before defining the stream.
# NOTE(review): for the Kafka batch source, `load()` appears to only build the relation
# (offsets are resolved when an action runs), so reaching the success message may not
# prove the broker is reachable. Consider running an action such as
# `test_df.limit(1).collect()` against an existing topic; with a non-existent topic such
# as "dummy" that action may itself fail -- confirm before changing.
from pyspark.sql.utils import StreamingQueryException

kafka_probe_options = {
    "kafka.bootstrap.servers": "kafka-1:9092",
    "subscribe": "dummy",
    "startingOffsets": "earliest",
}

try:
    test_df = spark.read.format("kafka").options(**kafka_probe_options).load()
except Exception as e:
    print(f"❌ Kafka connection failed: {e}")
    raise
else:
    print("✅ Kafka test connection successful")

✅ Kafka test connection successful


In [42]:
## 1. Extract: Read Streaming Data from Kafka

# Expected JSON layout of each Kafka message value. StructField defaults to
# nullable=True, so fields that are missing or fail to parse come back as null.
_METER_FIELDS = [
    ("meter_id", StringType()),
    ("timestamp", TimestampType()),
    ("kwh_usage", DoubleType()),
    ("voltage", IntegerType()),
    ("customer_id", StringType()),
    ("region", StringType()),
]
meter_schema = StructType([StructField(field_name, field_type)
                           for field_name, field_type in _METER_FIELDS])

# Streaming source: only messages arriving after the query starts ("latest"), and
# missing offsets (failOnDataLoss=false) are tolerated rather than failing the query.
_KAFKA_STREAM_OPTIONS = {
    "kafka.bootstrap.servers": "kafka-1:9092,kafka-2:9095",
    "subscribe": "smart_meter_data",
    "startingOffsets": "latest",
    "kafka.security.protocol": "PLAINTEXT",
    "failOnDataLoss": "false",
    "minPartitions": "1",
}
kafka_df = spark.readStream.format("kafka").options(**_KAFKA_STREAM_OPTIONS).load()

# Kafka delivers the payload as bytes: cast to string, parse as JSON, then flatten
# the resulting struct into top-level columns.
message_json = col("value").cast("string")
parsed_df = (kafka_df
    .select(from_json(message_json, meter_schema).alias("data"))
    .select("data.*"))

# Debug: echo parsed records to the console sink.
# NOTE(review): this query runs independently of the sinks started by run_streaming()
# and is not stopped there; call debug_query.stop() when it is no longer needed.
debug_query = parsed_df.writeStream.format("console").outputMode("append").start()

In [43]:
## 2. Transform: Clean and Enrich Data

# Define validation UDFs
@udf(returnType=BooleanType())
def is_valid_voltage(voltage: int) -> bool:
    """Spark UDF: True only for the two accepted nominal voltages, 230 and 240.

    Anything else -- including a missing (None) voltage -- yields False.
    """
    return voltage == 230 or voltage == 240

@udf(returnType=BooleanType())
def is_valid_kwh(kwh: float | int | None) -> bool:
    """Spark UDF: check that a kWh reading lies in the inclusive range [0, 20].

    Args:
        kwh: The reading, or None when the field was missing or unparseable.

    Returns:
        True if ``0 <= kwh <= 20``; False otherwise, including for None or NaN.
    """
    # Comparing None with an int raises TypeError inside the Python worker, which
    # would fail the entire micro-batch; treat a missing reading as invalid instead.
    if kwh is None:
        return False
    return 0 <= kwh <= 20

# Tariff per kWh by region; regions not listed (or a null region) use the default rate.
REGION_RATES_PER_KWH = {"Auckland": 0.25, "Wellington": 0.23}
DEFAULT_RATE_PER_KWH = 0.20
# Evening peak window, inclusive on both ends.
PEAK_START_HOUR = 17
PEAK_END_HOUR = 21
# Spark's dayofweek() numbers days 1 = Sunday ... 7 = Saturday.
WEEKEND_DAYS = [1, 7]

# Build CASE WHEN region = ... THEN rate ... ELSE default END from the tariff table,
# checking regions in dict order (first match wins, as with a chained when()).
_rate_per_kwh = None
for _region, _rate in REGION_RATES_PER_KWH.items():
    _is_region = col("region") == _region
    _rate_per_kwh = (when(_is_region, _rate) if _rate_per_kwh is None
                     else _rate_per_kwh.when(_is_region, _rate))
_rate_per_kwh = (lit(DEFAULT_RATE_PER_KWH) if _rate_per_kwh is None
                 else _rate_per_kwh.otherwise(DEFAULT_RATE_PER_KWH))

enhanced_df = (
    parsed_df
    # Drop records without their identifying keys first, so the derived columns and
    # the Python validation UDFs below only run on rows that are actually kept.
    .filter(
        col("meter_id").isNotNull() &
        col("customer_id").isNotNull() &
        col("timestamp").isNotNull()
    )

    # Time-based enrichment
    .withColumn("processing_time", current_timestamp())
    .withColumn("date", to_date(col("timestamp")))
    .withColumn("hour_of_day", hour(col("timestamp")))
    .withColumn("day_of_week", dayofweek(col("timestamp")))
    .withColumn("cost", col("kwh_usage") * _rate_per_kwh)
    .withColumn("is_peak", col("hour_of_day").between(PEAK_START_HOUR, PEAK_END_HOUR))

    # Data quality checks
    .withColumn("is_weekend", col("day_of_week").isin(WEEKEND_DAYS))
    .withColumn("is_valid_voltage", is_valid_voltage(col("voltage")))
    .withColumn("is_valid_kwh", is_valid_kwh(col("kwh_usage")))
    .withColumn("data_quality_flag",
        when(col("is_valid_voltage") & col("is_valid_kwh"), "VALID")
        .otherwise("INVALID"))

    # Record source
    .withColumn("source_system", lit("kafka_stream"))
)

In [44]:
## 3. Load: Write to Postgres

# Function to write batch to PostgreSQL
def write_to_postgres(batch_df: DataFrame, batch_id: Any) -> None:
    """Upsert one streaming micro-batch into PostgreSQL (``foreachBatch`` sink).

    Steps:
      1. Skip empty batches.
      2. Collect the fact columns to the driver as a pandas DataFrame.
      3. Insert any unseen ``customer_id`` / ``meter_id`` into ``dim_customer`` /
         ``dim_meter`` (``ON CONFLICT DO NOTHING``) and commit.
      4. Insert readings into ``fact_smart_meter_readings`` in chunks of 100 rows,
         committing after each chunk. On a ``unique_meter_timestamp`` conflict the
         stored row is overwritten only if it is currently flagged ``'INVALID'``.

    Errors are printed and swallowed so the stream keeps running. Because the function
    then returns normally, Spark treats the batch as done and does not retry it.

    Connection settings come from POSTGRES_HOST / POSTGRES_DB / POSTGRES_USER /
    POSTGRES_PASSWORD, defaulting to "postgres" for each.

    Args:
        batch_df: The micro-batch DataFrame produced from ``enhanced_df``.
        batch_id: Spark's micro-batch id; only used in log messages.
    """
    fact_columns = [
        "meter_id", "timestamp", "kwh_usage", "voltage",
        "customer_id", "region", "hour_of_day", "cost",
        "is_peak", "is_weekend", "processing_time",
        "date", "data_quality_flag", "source_system",
    ]
    chunk_size = 100

    if batch_df.isEmpty():
        print(f"Skipping empty batch {batch_id}")
        return

    batch_df.persist()
    conn = None

    try:
        pdf = batch_df.select(fact_columns).toPandas()

        # Convert timestamp columns
        pdf['timestamp'] = pd.to_datetime(pdf['timestamp'])
        pdf['processing_time'] = pd.to_datetime(pdf['processing_time'])
        pdf['date'] = pd.to_datetime(pdf['date']).dt.date

        # psycopg2 sends float NaN as 'NaN'::float rather than NULL. That stores NaN in
        # float columns and is rejected by integer ones (e.g. voltage, which pandas turns
        # into float64 when it contains nulls). Switch to object dtype -- plain Python
        # scalars -- and replace every missing value with None so it becomes SQL NULL.
        pdf = pdf.astype(object).where(pdf.notna(), None)

        conn = psycopg2.connect(
            host=os.environ.get("POSTGRES_HOST", "postgres"),
            dbname=os.environ.get("POSTGRES_DB", "postgres"),
            user=os.environ.get("POSTGRES_USER", "postgres"),
            password=os.environ.get("POSTGRES_PASSWORD", "postgres"),
            connect_timeout=5
        )
        # `with conn` wraps a transaction: commit on success, rollback on error. It does
        # NOT close the connection; the `finally` clause below does that.
        with conn, conn.cursor() as cur:
            # Ensure all customers and meters exist before inserting facts
            cur.executemany(
                """
                INSERT INTO dim_customer (customer_id)
                VALUES (%s) ON CONFLICT (customer_id) DO NOTHING
                """,
                [(cust,) for cust in pdf['customer_id'].unique()],
            )
            cur.executemany(
                """
                INSERT INTO dim_meter (meter_id)
                VALUES (%s) ON CONFLICT (meter_id) DO NOTHING
                """,
                [(meter,) for meter in pdf['meter_id'].unique()],
            )
            conn.commit()

            # Insert meter readings; only rows previously flagged INVALID are replaced
            insert_sql = """
                INSERT INTO fact_smart_meter_readings (
                    meter_id, timestamp, kwh_usage, voltage,
                    customer_id, region, hour_of_day, cost,
                    is_peak, is_weekend, processing_time,
                    date, data_quality_flag, source_system
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT ON CONSTRAINT unique_meter_timestamp
                DO UPDATE SET
                    kwh_usage = EXCLUDED.kwh_usage,
                    voltage = EXCLUDED.voltage,
                    cost = EXCLUDED.cost,
                    data_quality_flag = EXCLUDED.data_quality_flag,
                    processing_time = EXCLUDED.processing_time
                WHERE fact_smart_meter_readings.data_quality_flag = 'INVALID'
            """
            # Plain tuples in `fact_columns` order, matching the placeholder order above.
            rows = list(pdf.itertuples(index=False, name=None))
            for start in range(0, len(rows), chunk_size):
                cur.executemany(insert_sql, rows[start:start + chunk_size])
                conn.commit()

        print(f"Processed {len(pdf)} records (Batch {batch_id})")

    except Exception as e:
        # No explicit rollback: `with conn` already rolled back the open transaction.
        # Calling rollback() on a broken connection would itself raise and kill the stream.
        print(f"Error in batch {batch_id}: {str(e)}")
    finally:
        batch_df.unpersist()
        if conn is not None:
            conn.close()

In [45]:
## 4. Load: Write to MinIO/S3

# Function to write batch to MinIO
def write_to_minio(batch_df: DataFrame, batch_id: Any) -> None:
    """Archive one micro-batch to MinIO as Snappy-compressed Parquet (``foreachBatch`` sink).

    Uses the s3a settings already configured on the Spark session. For each non-empty
    batch it:
      1. writes and then deletes a small probe file under ``s3a://default/`` as a
         connectivity check (a failure here is only reported as a warning);
      2. appends the batch to ``s3a://default/smart_meter/raw/batch_id=<batch_id>/``.

    All errors are printed and swallowed so the stream keeps running.

    Args:
        batch_df: The micro-batch DataFrame produced from ``enhanced_df``.
        batch_id: Spark's micro-batch id; used in the output path and log messages.
    """
    try:
        if batch_df.isEmpty():
            print(f"Skipping empty batch {batch_id}")
            return

        # Unique probe path per batch so repeated runs never collide with an old one.
        test_path = f"s3a://default/.spark_test_{batch_id}_{int(time.time())}"

        try:
            # Test connection by writing and immediately deleting a test file
            test_rdd = spark.sparkContext.parallelize(["connection_test"])
            test_rdd.saveAsTextFile(test_path)

            # Resolve the FileSystem from the path itself. FileSystem.get(conf) returns the
            # *default* filesystem (fs.defaultFS), which rejects s3a:// paths ("Wrong FS"),
            # so the probe file was never removed.
            hadoop_conf = spark._jsc.hadoopConfiguration()
            probe = spark._jvm.org.apache.hadoop.fs.Path(test_path)
            probe.getFileSystem(hadoop_conf).delete(probe, True)

            print("MinIO connection test successful")
        except Exception as test_e:
            print(f"MinIO connection test warning: {str(test_e)}")
            # Continue processing despite test failure

        # One directory per micro-batch id
        output_path = f"s3a://default/smart_meter/raw/batch_id={batch_id}/"

        # Cache the batch: it is scanned by the write and again by count(); without
        # this the second scan recomputes the batch from the Kafka source.
        batch_df.persist()
        try:
            (batch_df.write
                .format("parquet")
                .mode("append")
                .option("compression", "snappy")
                .save(output_path))

            print(f"Successfully wrote {batch_df.count()} records to {output_path}")
        except Exception as write_e:
            print(f"Failed to write batch {batch_id}: {str(write_e)}")
        finally:
            batch_df.unpersist()

    except Exception as e:
        print(f"Unexpected error in batch {batch_id}: {str(e)}")

In [47]:
## 5. Execute the Streaming Pipeline

def run_streaming(poll_interval: float = 5.0) -> None:
    """Start both streaming sinks and monitor them until they stop or are interrupted.

    Starts two independent queries over ``enhanced_df``:
      * PostgreSQL via ``write_to_postgres`` (checkpoint ``/tmp/checkpoints/postgres``)
      * MinIO via ``write_to_minio`` (checkpoint ``/tmp/checkpoints/minio``)

    Every ``poll_interval`` seconds it prints each query's status message and reports,
    once, the exception of any query that has failed. Monitoring ends when neither query
    is active any more, on ``KeyboardInterrupt`` (e.g. Jupyter's "Interrupt kernel"), or
    on an unexpected error, which is printed to stderr. On every exit path, still-active
    queries are stopped.

    Tables are not created here; run ``initialize_tables()`` beforehand.

    Args:
        poll_interval: Seconds to sleep between status checks (default 5).
    """
    # Bound up front so the `finally` clause works even if a `start()` call raises.
    pg_query = None
    minio_query = None
    try:
        print("Starting streaming queries...")

        # NOTE(review): "continueOnError" does not appear to be a documented
        # DataStreamWriter option and is probably ignored -- confirm or remove.

        # Start PostgreSQL writer
        pg_query = (enhanced_df.writeStream
            .foreachBatch(write_to_postgres)
            .option("checkpointLocation", "/tmp/checkpoints/postgres")
            .option("continueOnError", "true")
            .start())

        # Start MinIO writer
        minio_query = (enhanced_df.writeStream
            .foreachBatch(write_to_minio)
            .option("checkpointLocation", "/tmp/checkpoints/minio")
            .option("continueOnError", "true")
            .start())

        queries = (("PostgreSQL", pg_query), ("MinIO", minio_query))
        reported_errors = set()
        while True:
            print(f"\nPostgreSQL Status: {pg_query.status['message']}")
            print(f"MinIO Status: {minio_query.status['message']}")

            # Report each failed query once rather than on every poll.
            for name, query in queries:
                if name in reported_errors:
                    continue
                if query_ex := query.exception():
                    print(f"{name} query error: {str(query_ex)}")
                    reported_errors.add(name)

            # Stop polling once there is nothing left running.
            if not any(query.isActive for _, query in queries):
                print("\nAll streaming queries have terminated")
                break

            sleep(poll_interval)

    except KeyboardInterrupt:
        print("\nUser requested shutdown...")
    except Exception as e:
        print(f"\nCRITICAL ERROR: {str(e)}", file=sys.stderr)
    finally:
        print("\nShutting down streams...")
        for name, q in [("PostgreSQL", pg_query), ("MinIO", minio_query)]:
            if q is not None and q.isActive:
                print(f"Stopping {name} query...")
                try:
                    q.stop()
                except Exception as e:
                    print(f"Error stopping {name} query: {str(e)}")
        print("All streams stopped")

In [48]:
## 6. Init tables
def initialize_tables() -> None:
    """Create the PostgreSQL dimension tables if they do not exist yet.

    Creates ``dim_customer`` and ``dim_meter``. Each has a VARCHAR(50) primary key and a
    ``created_at`` timestamp defaulting to the current time. Because of
    ``IF NOT EXISTS``, it is safe to call repeatedly.

    Connection settings come from POSTGRES_HOST / POSTGRES_DB / POSTGRES_USER /
    POSTGRES_PASSWORD, defaulting to "postgres" for each.

    NOTE(review): ``write_to_postgres`` also writes to ``fact_smart_meter_readings`` and
    relies on a ``unique_meter_timestamp`` constraint. Neither is created here, so they
    are assumed to be provisioned elsewhere -- confirm.

    Raises:
        Exception: Any connection or DDL error is printed and then re-raised.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS dim_customer (
            customer_id VARCHAR(50) PRIMARY KEY,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS dim_meter (
            meter_id VARCHAR(50) PRIMARY KEY,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
        )
        """,
    )
    conn = None
    try:
        conn = psycopg2.connect(
            host=os.environ.get("POSTGRES_HOST", "postgres"),
            dbname=os.environ.get("POSTGRES_DB", "postgres"),
            user=os.environ.get("POSTGRES_USER", "postgres"),
            password=os.environ.get("POSTGRES_PASSWORD", "postgres"),
            connect_timeout=5
        )
        # `with conn` commits on success and rolls back on error, but leaves the
        # connection open -- it is closed in `finally`.
        with conn, conn.cursor() as cur:
            for ddl in ddl_statements:
                cur.execute(ddl)
    except Exception as e:
        print(f"Error initializing tables: {str(e)}")
        raise
    finally:
        if conn is not None:
            conn.close()

In [50]:
## 7. Init tables: create the PostgreSQL dimension tables (idempotent, IF NOT EXISTS)
initialize_tables()

In [51]:
## 8. Run ETL
# Blocks this cell while the PostgreSQL and MinIO streaming queries run;
# interrupt the kernel to shut them down cleanly.
run_streaming()

Starting streaming queries...

PostgreSQL Status: Processing new data
MinIO Status: Initializing sources
MinIO connection test failed: RDD.saveAsTextFile() got an unexpected keyword argument 'mode'
Failed to write batch 4: RDD.saveAsTextFile() got an unexpected keyword argument 'mode'
MinIO connection test failed: An error occurred while calling o1211.saveAsTextFile.
: org.apache.hadoop.mapred.FileAlreadyExistsException: Output directory s3a://default/.spark_test_file already exists
	at org.apache.hadoop.mapred.FileOutputFormat.checkOutputSpecs(FileOutputFormat.java:131)
	at org.apache.spark.internal.io.HadoopMapRedWriteConfigUtil.assertConf(SparkHadoopWriter.scala:299)
	at org.apache.spark.internal.io.SparkHadoopWriter$.write(SparkHadoopWriter.scala:71)
	at org.apache.spark.rdd.PairRDDFunctions.$anonfun$saveAsHadoopDataset$1(PairRDDFunctions.scala:1091)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDO