In [0]:
%pip install psycopg2-binary
# NOTE(review): package versions are unpinned; pin them (e.g. psycopg2-binary==X.Y.Z) for reproducible runs.
%pip install kafka-python
# NOTE(review): kafka-python is not imported by any cell shown here (the stream below uses Spark's
# built-in "kafka" source); presumably it is kept for the producer notebook — confirm or drop it.
# Restart the Python process so the packages installed above are importable in later cells.
dbutils.library.restartPython()

In [0]:
# Configuration - every connection setting below is read from the "rtm_demo" secret scope,
# so no credentials are hard-coded in this notebook.
#   JDBC_*  -> target PostgreSQL database. JDBC_DRIVER is not used by the psycopg2-based
#              writer further down; it is only passed through for call-site compatibility.
#   kafka_* -> SASL/SCRAM credentials and topic name for the source Kafka cluster.

JDBC_URL = dbutils.secrets.get(scope="rtm_demo", key="db_url")
JDBC_USER = dbutils.secrets.get(scope="rtm_demo", key="db_user")
JDBC_PASSWORD = dbutils.secrets.get(scope="rtm_demo", key="db_password")
JDBC_DRIVER = "org.postgresql.Driver"
TABLE_NAME = "transactions_stream"
kafka_user = dbutils.secrets.get(scope="rtm_demo", key="kafka_user")
kafka_password = dbutils.secrets.get(scope="rtm_demo", key="kafka_password")
kafka_topic = dbutils.secrets.get(scope="rtm_demo", key="topic")


def _jaas_quote(value):
    """Escape ``value`` for use inside a double-quoted ``sasl.jaas.config`` option.

    Kafka's JAAS config parser treats ``"`` as the quote character and ``\\`` as an escape
    inside quoted values, so a credential containing either character would otherwise end
    the quoted value early and corrupt the login configuration.
    """
    return value.replace("\\", "\\\\").replace('"', '\\"')


jaas_config = (
    'kafkashaded.org.apache.kafka.common.security.scram.ScramLoginModule required '
    f'username="{_jaas_quote(kafka_user)}" '
    f'password="{_jaas_quote(kafka_password)}";'
)


In [0]:
# Kafka source settings: brokers come from the secret scope, authentication is SASL_SSL + SCRAM.
kafka_source_options = {
    "kafka.bootstrap.servers": dbutils.secrets.get(scope="rtm_demo", key="kafka.bootstrap.servers"),
    "subscribe": kafka_topic,
    "startingOffsets": "latest",          # or "earliest"
    "kafka.security.protocol": "SASL_SSL",
    "kafka.sasl.mechanism": "SCRAM-SHA-256",
    "kafka.sasl.jaas.config": jaas_config,
}

raw_kafka_df = spark.readStream.format("kafka").options(**kafka_source_options).load()

# Kafka key/value arrive as binary: expose the key as customerID, the value as raw JSON
# text, and keep the Kafka record timestamp.
decode_exprs = [
    "CAST(key AS STRING) AS customerID",
    "CAST(value AS STRING) AS value_json",
    "timestamp",
]
events_df = raw_kafka_df.selectExpr(*decode_exprs)


In [0]:
# Read the Deduplicate widget ("Yes" selects the deduplicating pipeline later on).
deduplicate = dbutils.widgets.get("Deduplicate")

# Shuffle partitions should equal the Kafka topic's partition count.
# NOTE: keep this in sync with kafka_partitions in the producer notebook,
# which creates the topic with 8 partitions.
num_partitions = 8
dedup_state = "enabled" if deduplicate == "Yes" else "disabled"

print(f"Setting shuffle partitions to {num_partitions} (matching Kafka topic partition count)")
print(f"Deduplication {dedup_state}: Using {num_partitions} shuffle partitions")

spark.conf.set("spark.sql.shuffle.partitions", num_partitions)

current_partitions = spark.conf.get('spark.sql.shuffle.partitions')
print(f"\nCurrent shuffle partitions: {current_partitions}")

In [0]:
import psycopg2
from psycopg2.extras import execute_batch
from datetime import datetime
from urllib.parse import urlparse
import time


def parse_pg_url(jdbc_url, jdbc_user, jdbc_password):
    """Convert a JDBC PostgreSQL URL into ``psycopg2.connect`` keyword arguments.

    Accepts ``jdbc:postgresql://host[:port]/dbname[?params]``; the ``jdbc:`` prefix is
    optional and only a *leading* ``jdbc:`` is removed. SSL query parameters whose names
    and values mean the same thing to pgjdbc and libpq (``sslmode``, ``sslrootcert``) are
    carried over, so e.g. ``?sslmode=require`` is honoured. Other JDBC-only parameters are
    ignored because libpq rejects unknown connection option names.

    Args:
        jdbc_url: JDBC URL of the target database.
        jdbc_user: Database user name.
        jdbc_password: Database password.

    Returns:
        dict with ``host``, ``port`` (5432 when absent), ``dbname``, ``user``, ``password``,
        plus any of the pass-through SSL keys present in the URL query string.
    """
    from urllib.parse import parse_qsl

    passthrough_params = ("sslmode", "sslrootcert")

    url = jdbc_url[len("jdbc:"):] if jdbc_url.startswith("jdbc:") else jdbc_url
    parsed = urlparse(url)
    conn_kwargs = {
        "host": parsed.hostname,
        "port": parsed.port or 5432,
        "dbname": parsed.path.lstrip("/"),
        "user": jdbc_user,
        "password": jdbc_password,
    }
    for key, value in parse_qsl(parsed.query):
        if key in passthrough_params:
            conn_kwargs[key] = value
    return conn_kwargs


def make_pg_buffered_writer(
    jdbc_url,
    jdbc_user,
    jdbc_password,
    jdbc_driver,
    table_name,
    pool_size,
    max_batch_size=100,   # flush when >= this many rows (0 = simple mode)
    flush_secs=2.0,       # flush if last flush older than this (0 = simple mode)
    verbose=False,        # log every row (noisy, and rows may carry sensitive payloads)
):
    """Build a Structured Streaming ``foreach`` writer that inserts rows into PostgreSQL.

    Each ``open(partition_id, epoch_id)`` call creates its own psycopg2 connection. The
    INSERT statement is derived from the first row's field names, so the target table must
    have columns with the same (unquoted) names.

    Modes:
      * buffered (``max_batch_size > 0`` and ``flush_secs > 0``): rows are collected and
        written with ``execute_batch`` + commit once the buffer holds ``max_batch_size`` rows
        or ``flush_secs`` have passed since the last flush. The time check only runs when a
        row arrives (or on ``close``), so an idle partition keeps its rows until then.
      * simple (otherwise): every row is inserted and committed immediately.

    Connection, insert and final-flush failures are raised, so the Spark task fails (and
    may be retried) instead of the epoch being treated as written.

    Args:
        jdbc_url, jdbc_user, jdbc_password: connection settings, parsed by ``parse_pg_url``.
        jdbc_driver: unused; kept for call-site compatibility (psycopg2 needs no JDBC driver).
        table_name: target table, interpolated verbatim into the INSERT statement.
        pool_size: unused; kept for call-site compatibility.
        max_batch_size: buffered-mode flush threshold in rows.
        flush_secs: buffered-mode flush interval in seconds.
        verbose: when True, log every processed / immediately-written row.

    Returns:
        A ``PgForeachWriter`` instance for ``DataStreamWriter.foreach``.
    """
    conn_kwargs = parse_pg_url(jdbc_url, jdbc_user, jdbc_password)

    # Buffered mode needs both thresholds > 0; anything else writes row-by-row.
    use_buffered_mode = max_batch_size > 0 and flush_secs > 0

    if use_buffered_mode:
        print(f"[JDBC] Using BUFFERED mode (max_batch_size={max_batch_size}, flush_secs={flush_secs})")
    else:
        print(f"[JDBC] Using SIMPLE mode (immediate writes)")

    class PgForeachWriter:
        """psycopg2 writer with one connection per (partition, epoch); see the factory docstring."""

        def open(self, partition_id, epoch_id):
            """Connect for this (partition, epoch) and return True; raise if connecting fails.

            Returning False would tell Spark to skip this partition, silently dropping its
            rows, so connection errors are propagated instead.
            """
            print(f"[JDBC] open partition={partition_id}, epoch={epoch_id}")
            self.conn = None
            self.cursor = None
            # Column list and INSERT statement are derived lazily from the first row.
            self.column_names = None
            self.sql = None
            # Buffering state (only consulted in buffered mode).
            self.buffer = []
            self.last_flush_ts = time.time()
            try:
                self.conn = psycopg2.connect(**conn_kwargs)
                self.conn.autocommit = False
                self.cursor = self.conn.cursor()
            except Exception as e:
                print(f"[JDBC] Error opening connection: {e}")
                # Clean up here: close() is not guaranteed to run when open() raises.
                self._close_resources()
                raise
            return True

        def _close_resources(self):
            """Best-effort close of cursor then connection; close errors are only logged."""
            for attr in ("cursor", "conn"):
                resource = getattr(self, attr, None)
                if resource is not None:
                    try:
                        resource.close()
                    except Exception as e:
                        print(f"[JDBC] Error closing {attr}: {e}")
                setattr(self, attr, None)

        def _safe_rollback(self):
            """Roll back without letting a rollback failure mask the original error."""
            try:
                self.conn.rollback()
            except Exception as rollback_error:
                print(f"[JDBC] Error during rollback: {rollback_error}")

        def _initialize_schema(self, row):
            """Build the INSERT statement from the first row's field names (once per open)."""
            if self.column_names is not None:
                return  # Already initialized

            # Freeze the field order so every later row is serialized identically.
            self.column_names = tuple(row.asDict().keys())
            columns_str = ", ".join(self.column_names)
            placeholders = ", ".join(["%s"] * len(self.column_names))
            self.sql = f"INSERT INTO {table_name} ({columns_str}) VALUES ({placeholders})"

            print(f"[JDBC] Initialized schema with columns: {list(self.column_names)}")
            print(f"[JDBC] SQL: {self.sql}")

        @staticmethod
        def _to_db_value(val):
            """Prepare one field for psycopg2: datetimes become ISO strings, others pass through."""
            if isinstance(val, datetime):
                return val.isoformat(sep=" ", timespec="microseconds")
            return val

        def _flush_if_needed(self, force=False):
            """Write the buffer as one batch + commit when forced, full, or stale (buffered mode only)."""
            if not use_buffered_mode:
                return  # Not applicable in simple mode

            now = time.time()
            need_by_size = len(self.buffer) >= max_batch_size
            need_by_time = (now - self.last_flush_ts) >= flush_secs
            if not (force or need_by_size or need_by_time):
                return

            if not self.buffer:
                self.last_flush_ts = now
                return

            batch = list(self.buffer)
            self.buffer.clear()
            self.last_flush_ts = now

            try:
                print(f"[JDBC] flushing {len(batch)} rows")
                execute_batch(self.cursor, self.sql, batch)
                self.conn.commit()
            except Exception as e:
                print(f"[JDBC] Error flushing batch: {e}")
                # Put rows back so they're not lost if a flush is attempted again.
                self.buffer[:0] = batch
                self._safe_rollback()
                raise

        def _write_immediate(self, values):
            """Insert and commit a single row (simple mode only)."""
            if verbose:
                print(f"[JDBC] writing row immediately")
            try:
                self.cursor.execute(self.sql, values)
                self.conn.commit()
            except Exception as e:
                print(f"[JDBC] Error writing row: {e}")
                self._safe_rollback()
                raise

        def process(self, row):
            """Convert one Spark Row to a parameter tuple and buffer or write it."""
            self._initialize_schema(row)
            if verbose:
                print(f"[JDBC] process row: {row}")

            row_dict = row.asDict()
            values = tuple(self._to_db_value(row_dict[name]) for name in self.column_names)

            if use_buffered_mode:
                self.buffer.append(values)
                self._flush_if_needed(force=False)
            else:
                self._write_immediate(values)

        def close(self, error):
            """Flush remaining rows (only when no error occurred) and release the connection.

            A failed final flush is re-raised after cleanup so Spark does not consider the
            epoch written while those rows were never committed.
            """
            print(f"[JDBC] close, error={error}")
            try:
                if error is None and use_buffered_mode and getattr(self, "cursor", None) is not None:
                    self._flush_if_needed(force=True)
            except Exception as e:
                print(f"[JDBC] Error flushing on close: {e}")
                raise
            finally:
                self._close_resources()

    return PgForeachWriter()

In [0]:
# Read the buffering knobs from the notebook widgets.
# NOTE: the streaming-query cell re-reads these widgets and builds its own writer.
max_batch_size_param = int(dbutils.widgets.get("max_batch_size"))
flush_secs_param = float(dbutils.widgets.get("flush_secs"))

print("Buffered writer configuration:")
for setting_name, setting_value in (
    ("max_batch_size", max_batch_size_param),
    ("flush_secs", flush_secs_param),
):
    print(f"  {setting_name}: {setting_value}")
print()

writer_settings = dict(
    jdbc_url=JDBC_URL,
    jdbc_user=JDBC_USER,
    jdbc_password=JDBC_PASSWORD,
    jdbc_driver=JDBC_DRIVER,
    table_name=TABLE_NAME,
    pool_size=None,
    max_batch_size=max_batch_size_param,
    flush_secs=flush_secs_param,
)
jdbc_writer = make_pg_buffered_writer(**writer_settings)

In [0]:
import psycopg2
from urllib.parse import urlparse
from pyspark.sql.types import StringType, TimestampType, IntegerType, LongType, DoubleType, FloatType, BooleanType, DateType

# parse_pg_url is intentionally not redefined here: the writer cell above already defines it.
# A second copy would silently shadow the first, and make_pg_buffered_writer looks the name up
# at call time, so a fix applied to only one copy could be undone just by running this cell.

# Map Spark types to PostgreSQL types
def spark_to_postgres_type(spark_type):
    """Return the PostgreSQL column type to use for a Spark SQL data type instance.

    Types without an explicit entry below (e.g. arrays, maps, structs) fall back to TEXT.
    """
    type_table = (
        (StringType, "TEXT"),
        (TimestampType, "TIMESTAMP"),
        (DateType, "DATE"),
        (IntegerType, "INTEGER"),
        (LongType, "BIGINT"),
        ((DoubleType, FloatType), "DOUBLE PRECISION"),
        (BooleanType, "BOOLEAN"),
    )
    for spark_classes, pg_type in type_table:
        if isinstance(spark_type, spark_classes):
            return pg_type
    return "TEXT"  # Default fallback

conn_kwargs = parse_pg_url(JDBC_URL, JDBC_USER, JDBC_PASSWORD)

print(f"Connecting to PostgreSQL database: {conn_kwargs['dbname']}")
print(f"Target table: {TABLE_NAME}")
print()

# Get schema from the streaming DataFrame
print("Inspecting streaming DataFrame schema...")
df_schema = events_df.schema

print("DataFrame columns:")
for field in df_schema.fields:
    print(f"  - {field.name}: {field.dataType}")
print()

# Build the CREATE TABLE statement from the DataFrame schema. Column names are left unquoted,
# exactly as in the foreach writer's INSERT statement, so PostgreSQL folds both the same way.
column_definitions = []
for field in df_schema.fields:
    col_name = field.name
    pg_type = spark_to_postgres_type(field.dataType)
    nullable = "" if field.nullable else "NOT NULL"
    column_definitions.append(f"{col_name} {pg_type} {nullable}".strip())

# Join outside the f-string: a backslash inside an f-string expression is a
# SyntaxError on Python versions before 3.12.
column_sql = ",\n    ".join(column_definitions)
create_sql = f"""
CREATE TABLE {TABLE_NAME} (
    {column_sql}
)
"""

# Connect and drop/recreate the table; cursor and connection are closed even on failure.
conn = None
cursor = None
try:
    conn = psycopg2.connect(**conn_kwargs)
    conn.autocommit = True
    cursor = conn.cursor()

    # Drop table if exists
    drop_sql = f"DROP TABLE IF EXISTS {TABLE_NAME}"
    print(f"Executing: {drop_sql}")
    cursor.execute(drop_sql)
    print("  ✓ Table dropped (if existed)")
    print()

    # Create table with schema matching DataFrame
    print(f"Executing: {create_sql}")
    cursor.execute(create_sql)
    print("  ✓ Table created successfully")
    print()

    # Verify the table exists in the schema it was just created in. Unquoted identifiers are
    # stored lower-cased, and the name is bound as a parameter instead of spliced into SQL.
    cursor.execute(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = current_schema()
          AND table_name = %s
        ORDER BY ordinal_position
        """,
        (TABLE_NAME.lower(),),
    )
    columns = cursor.fetchall()
    if not columns:
        raise RuntimeError(f"Table {TABLE_NAME} not found in information_schema after CREATE TABLE")

    print("PostgreSQL table schema:")
    for col_name, col_type in columns:
        print(f"  - {col_name}: {col_type}")

    print()
    print("✓ Table ready for streaming writes")

except Exception as e:
    print(f"Error creating table: {e}")
    raise
finally:
    if cursor is not None:
        cursor.close()
    if conn is not None:
        conn.close()

In [0]:
import pyspark.sql.functions as F

# Re-read widget parameters to ensure we use the latest values
print("Reading widget parameters...")
deduplicate = dbutils.widgets.get("Deduplicate")
max_batch_size_param = int(dbutils.widgets.get("max_batch_size"))
flush_secs_param = float(dbutils.widgets.get("flush_secs"))

print(f"Widget values:")
print(f"  Deduplicate: {deduplicate}")
print(f"  max_batch_size: {max_batch_size_param}")
print(f"  flush_secs: {flush_secs_param}")
print()

# Recreate the JDBC writer with current widget values
print("Creating JDBC writer with current parameters...")
jdbc_writer = make_pg_buffered_writer(
    jdbc_url=JDBC_URL,
    jdbc_user=JDBC_USER,
    jdbc_password=JDBC_PASSWORD,
    jdbc_driver=JDBC_DRIVER,
    table_name=TABLE_NAME,
    pool_size=None,
    max_batch_size=max_batch_size_param,
    flush_secs=flush_secs_param
)
print("  ✓ JDBC writer created")
print()


def _stop_active_query(query_name):
    """Stop an already-running query with this name so the cell can be re-run.

    Active query names must be unique, and its checkpoint is about to be deleted.
    """
    for active_query in spark.streams.active:
        if active_query.name == query_name:
            print(f"Stopping already-running query: {query_name}")
            active_query.stop()


def _reset_checkpoint(checkpoint_path):
    """Delete the checkpoint directory so the query does not resume from earlier progress."""
    print(f"Deleting checkpoint location: {checkpoint_path}")
    try:
        dbutils.fs.rm(checkpoint_path, recurse=True)
        print("  Checkpoint deleted successfully")
    except Exception as e:
        print(f"  Checkpoint deletion skipped (may not exist): {e}")
    print()


def _start_jdbc_stream(stream_df, writer, query_name, checkpoint_path):
    """Start a real-time (update-mode) foreach query that writes ``stream_df`` via ``writer``."""
    return (
        stream_df
        .writeStream
        .foreach(writer)
        .outputMode("update")
        .queryName(query_name)
        .trigger(realTime="5 minutes")
        .option("checkpointLocation", checkpoint_path)
        .start()
    )


df_for_write = events_df  # Start with the base streaming DataFrame

if deduplicate == "Yes":
    checkpoint_path = "/tmp/foreach_jdbc_checkpoint_deduplicated"
    query_name = "jdbc_sink_writer_deduplicated"
    query_label = "Deduplicated"

    print("Executing streaming query WITH deduplication...")
    print("  - Parsing JSON to extract amount and daily_average")
    print("  - Adding 10-minute watermark")
    print("  - Dropping duplicates based on customerID, amount, and daily_average")
    print()

    # The JSON payload is parsed only to obtain the dedup keys; the output is projected back
    # to the target table's columns.
    # NOTE(review): "timestamp" (the watermark column) is not a dropDuplicates key, so the
    # watermark cannot expire dedup state and state grows with the number of distinct keys.
    # If bounded state is intended, consider dropDuplicatesWithinWatermark (verify that
    # real-time mode supports it first).
    stream_df = (
        df_for_write
        .withColumn("parsed_json", F.from_json("value_json", "amount STRING, daily_average STRING"))
        .withColumn("amount", F.col("parsed_json.amount"))
        .withColumn("daily_average", F.col("parsed_json.daily_average"))
        .withWatermark("timestamp", "10 minutes")
        .dropDuplicates(["customerID", "amount", "daily_average"])
        .select("customerID", "value_json", "timestamp")
    )
else:
    checkpoint_path = "/tmp/foreach_jdbc_checkpoint_standard"
    query_name = "jdbc_sink_writer"
    query_label = "Standard"

    print("Executing streaming query WITHOUT deduplication...")
    print()

    stream_df = df_for_write

_stop_active_query(query_name)
_reset_checkpoint(checkpoint_path)
query = _start_jdbc_stream(stream_df, jdbc_writer, query_name, checkpoint_path)

print(f"{query_label} streaming query started: {query.id}")