In [0]:
%pip install jaydebeapi
%pip install SQLAlchemy
dbutils.library.restartPython()
# You will also need to install the PostgreSQL JDBC jar and make it available on the classpath.
# The version used in this notebook is postgresql-42.7.1.
# NOTE(review): no cell in this notebook imports jaydebeapi or SQLAlchemy, and the writer
# factories accept the JDBC driver class / jar path without ever reading them. The writers
# actually talk to PostgreSQL through psycopg2, which this cell does not install — confirm
# psycopg2 is available on the cluster. Consider pinning versions (e.g. `pkg==x.y.z`).

In [0]:
# Configuration — every connection setting is read from the "rtm_demo" Databricks
# secret scope; only the JDBC driver class and the target table are literal values.
SECRET_SCOPE = "rtm_demo"


def _secret(key):
    """Return the secret stored under `key` in SECRET_SCOPE."""
    return dbutils.secrets.get(scope=SECRET_SCOPE, key=key)


def _jaas_quote(value):
    """Escape backslashes and double quotes so `value` can sit inside a quoted JAAS option."""
    return value.replace("\\", "\\\\").replace('"', '\\"')


JDBC_URL = _secret("db_url")
JDBC_USER = _secret("db_user")
JDBC_PASSWORD = _secret("db_password")
JDBC_DRIVER = "org.postgresql.Driver"
JDBC_JAR_PATH = _secret("jdbc_driver_jar_path")
TABLE_NAME = "transactions_stream"
kafka_user = _secret("kafka_user")
kafka_password = _secret("kafka_password")
kafka_topic = _secret("topic")
# The credentials are escaped because `"` and `\` are special inside a double-quoted
# JAAS option value; unescaped, a password containing them corrupts the login config.
jaas_config = (
    'kafkashaded.org.apache.kafka.common.security.scram.ScramLoginModule required '
    f'username="{_jaas_quote(kafka_user)}" '
    f'password="{_jaas_quote(kafka_password)}";'
)


In [0]:
# Kafka source options: SASL_SSL transport with SCRAM-SHA-256 authentication using the
# JAAS string assembled in the configuration cell.
_kafka_source_options = {
    "kafka.bootstrap.servers": dbutils.secrets.get(scope="rtm_demo", key="kafka.bootstrap.servers"),
    "subscribe": kafka_topic,
    "startingOffsets": "latest",          # or "earliest"
    "kafka.security.protocol": "SASL_SSL",
    "kafka.sasl.mechanism": "SCRAM-SHA-256",
    "kafka.sasl.jaas.config": jaas_config,
}
raw_kafka_df = spark.readStream.format("kafka").options(**_kafka_source_options).load()

# Cast the binary Kafka key/value to strings and keep the record timestamp.
_decoded_columns = [
    "CAST(key AS STRING) AS customerID",
    "CAST(value AS STRING) AS value_json",
    "timestamp",
]
events_df = raw_kafka_df.selectExpr(*_decoded_columns)


In [0]:
# Post-shuffle partition count used by shuffling operators such as dropDuplicates.
# For Structured Streaming this cannot be changed between restarts from the same
# checkpoint location, so a new value needs a fresh checkpoint.
SHUFFLE_PARTITIONS = 4
spark.conf.set("spark.sql.shuffle.partitions", SHUFFLE_PARTITIONS)

In [0]:
import psycopg2
from psycopg2 import pool as pg_pool
from datetime import datetime
from urllib.parse import urlparse


def parse_pg_url(jdbc_url, jdbc_user, jdbc_password):
    """
    Parse a PostgreSQL URL into connection kwargs for psycopg2.

    Accepts ``jdbc:postgresql://host:port/db`` as well as plain
    ``postgresql://`` / ``postgres://`` URLs; the port defaults to 5432.
    An ``sslmode`` query parameter (same keyword and values in pgjdbc and
    libpq) is forwarded. Other query parameters are ignored because
    JDBC-specific options may not be valid libpq keywords.

    Args:
        jdbc_url: the database URL, with or without a leading ``jdbc:``.
        jdbc_user: database user name.
        jdbc_password: database password.

    Returns:
        dict with ``host``, ``port``, ``dbname``, ``user``, ``password`` and,
        when the URL carries one, ``sslmode``.
    """
    from urllib.parse import parse_qs, urlparse

    # Strip only a leading "jdbc:"; str.replace would also rewrite later occurrences.
    url = jdbc_url[len("jdbc:"):] if jdbc_url.startswith("jdbc:") else jdbc_url
    parsed = urlparse(url)

    conn_kwargs = {
        "host": parsed.hostname,
        "port": parsed.port or 5432,
        "dbname": parsed.path.lstrip("/"),
        "user": jdbc_user,
        "password": jdbc_password,
    }
    sslmode_values = parse_qs(parsed.query).get("sslmode")
    if sslmode_values:
        conn_kwargs["sslmode"] = sslmode_values[-1]
    return conn_kwargs


def make_pg_writer(jdbc_url, jdbc_user, jdbc_password, jdbc_driver,
                   jdbc_jar_path, table_name, pool_size):
    """
    Build a per-row PostgreSQL ForeachWriter for ``DataStreamWriter.foreach``.

    Each partition/epoch opens its own autocommit psycopg2 connection, and every
    row is written with a single INSERT, so each row commits on its own.

    Args:
        jdbc_url: ``jdbc:postgresql://host:port/db`` URL, converted by ``parse_pg_url``.
        jdbc_user: database user.
        jdbc_password: database password.
        jdbc_driver: unused; kept for signature compatibility.
        jdbc_jar_path: unused; kept for signature compatibility.
        table_name: target table. It is interpolated into the SQL text verbatim,
            so it must be a trusted identifier, never external input.
        pool_size: unused; kept for signature compatibility.

    Returns:
        A ``SimplePgForeachWriter`` instance.
    """
    conn_kwargs = parse_pg_url(jdbc_url, jdbc_user, jdbc_password)

    # The statement text never changes, so build it once instead of once per row.
    insert_sql = (
        f"INSERT INTO {table_name} (customer_id, value_json, event_ts) "
        "VALUES (%s, %s::jsonb, %s::timestamp)"
    )

    def _ts_to_text(ts):
        """Render a row timestamp as text for the ``::timestamp`` cast."""
        if isinstance(ts, datetime):
            return ts.isoformat(sep=" ", timespec="microseconds")
        return str(ts)

    # ========= 1) SIMPLE WRITER =========
    class SimplePgForeachWriter:
        def open(self, partition_id, epoch_id):
            """Open a connection for this partition; return False if that fails."""
            self.conn = None
            self.cursor = None
            try:
                self.conn = psycopg2.connect(**conn_kwargs)
                self.conn.autocommit = True  # match previous autocommit behavior
                self.cursor = self.conn.cursor()
                return True
            except Exception as e:
                print(f"Error opening psycopg2 connection: {e}")
                # Close a connection that opened but failed later in setup,
                # rather than dropping the reference and leaking it.
                self._release()
                # NOTE(review): per the ForeachWriter contract, returning False makes
                # Spark skip this partition's rows without failing the task, so they
                # are never written. Raising instead would let Spark retry the task.
                return False

        def process(self, row):
            """Insert one row's customerID, value_json and timestamp."""
            self.cursor.execute(
                insert_sql,
                (row["customerID"], row["value_json"], _ts_to_text(row["timestamp"])),
            )

        def close(self, error):
            """Release the cursor and connection; cleanup errors are ignored."""
            self._release()

        def _release(self):
            """Best-effort close of the cursor, then the connection."""
            for attr in ("cursor", "conn"):
                handle = getattr(self, attr, None)
                if handle is not None:
                    try:
                        handle.close()
                    except Exception:
                        pass
                setattr(self, attr, None)

    return SimplePgForeachWriter()



In [0]:
import psycopg2
from psycopg2.extras import execute_batch
from datetime import datetime
from urllib.parse import urlparse
import time


def parse_pg_url(jdbc_url, jdbc_user, jdbc_password):
    """
    Parse a PostgreSQL URL into connection kwargs for psycopg2.

    Accepts ``jdbc:postgresql://host:port/db`` as well as plain
    ``postgresql://`` / ``postgres://`` URLs; the port defaults to 5432.
    An ``sslmode`` query parameter (same keyword and values in pgjdbc and
    libpq) is forwarded. Other query parameters are ignored because
    JDBC-specific options may not be valid libpq keywords.

    Args:
        jdbc_url: the database URL, with or without a leading ``jdbc:``.
        jdbc_user: database user name.
        jdbc_password: database password.

    Returns:
        dict with ``host``, ``port``, ``dbname``, ``user``, ``password`` and,
        when the URL carries one, ``sslmode``.
    """
    from urllib.parse import parse_qs, urlparse

    # Strip only a leading "jdbc:"; str.replace would also rewrite later occurrences.
    url = jdbc_url[len("jdbc:"):] if jdbc_url.startswith("jdbc:") else jdbc_url
    parsed = urlparse(url)

    conn_kwargs = {
        "host": parsed.hostname,
        "port": parsed.port or 5432,
        "dbname": parsed.path.lstrip("/"),
        "user": jdbc_user,
        "password": jdbc_password,
    }
    sslmode_values = parse_qs(parsed.query).get("sslmode")
    if sslmode_values:
        conn_kwargs["sslmode"] = sslmode_values[-1]
    return conn_kwargs


def make_pg_buffered_writer(
    jdbc_url,
    jdbc_user,
    jdbc_password,
    jdbc_driver,
    jdbc_jar_path,
    table_name,
    pool_size,
    max_batch_size=100,   # flush when >= this many rows
    flush_secs=2.0,       # flush if last flush older than this
    log_rows=False,       # print every incoming row (debug only; echoes payloads)
):
    """
    Build a batching PostgreSQL ForeachWriter for ``DataStreamWriter.foreach``.

    Rows are buffered per partition/epoch and written with ``execute_batch``,
    committing once per flush. A flush is attempted whenever a row arrives and
    fires when the buffer holds at least ``max_batch_size`` rows or
    ``flush_secs`` have elapsed since the last flush. Whatever remains is
    flushed in ``close`` when the partition finished without error. The time
    check only runs when a row arrives, so an idle partition keeps its buffer
    until ``close``.

    Args:
        jdbc_url: ``jdbc:postgresql://host:port/db`` URL, converted by ``parse_pg_url``.
        jdbc_user: database user.
        jdbc_password: database password.
        jdbc_driver: unused; kept for signature compatibility.
        jdbc_jar_path: unused; kept for signature compatibility.
        table_name: target table, interpolated into the SQL verbatim; it must be
            a trusted identifier, never external input.
        pool_size: unused; kept for signature compatibility.
        max_batch_size: flush once the buffer holds at least this many rows.
        flush_secs: flush once this many seconds have passed since the last flush.
        log_rows: when True, print each row in ``process``. Off by default
            because it logs every payload and slows the sink down.

    Returns:
        A ``BufferedPgForeachWriter`` instance.
    """
    conn_kwargs = parse_pg_url(jdbc_url, jdbc_user, jdbc_password)
    insert_sql = (
        f"INSERT INTO {table_name} (customer_id, value_json, event_ts) "
        f"VALUES (%s, %s::jsonb, %s::timestamp)"
    )

    class BufferedPgForeachWriter:
        def open(self, partition_id, epoch_id):
            """Open a transactional connection for this partition; False on failure."""
            print(f"[JDBC] open partition={partition_id}, epoch={epoch_id}")
            # Initialise all state up front so close() is safe even if connecting fails.
            self.conn = None
            self.cursor = None
            self.buffer = []
            # Monotonic clock: flush intervals are unaffected by wall-clock adjustments.
            self.last_flush_ts = time.monotonic()
            self.sql = insert_sql
            try:
                self.conn = psycopg2.connect(**conn_kwargs)
                self.conn.autocommit = False
                self.cursor = self.conn.cursor()
                return True
            except Exception as e:
                print(f"[JDBC] Error opening connection: {e}")
                # Close a half-opened connection instead of leaking it.
                self._release()
                # NOTE(review): per the ForeachWriter contract, returning False makes
                # Spark skip this partition's rows without failing the task.
                return False

        def _flush_if_needed(self, force=False):
            """Write and commit the buffer when forced or a threshold is reached.

            Raises:
                Exception: any error from the batch insert/commit, after the rows
                have been put back into the buffer and the transaction rolled back.
            """
            now = time.monotonic()
            need_by_size = len(self.buffer) >= max_batch_size
            need_by_time = (now - self.last_flush_ts) >= flush_secs

            if not (force or need_by_size or need_by_time):
                return

            if not self.buffer:
                self.last_flush_ts = now
                return

            batch = list(self.buffer)
            self.buffer.clear()
            self.last_flush_ts = now

            try:
                print(f"[JDBC] flushing {len(batch)} rows")
                execute_batch(self.cursor, self.sql, batch)
                self.conn.commit()
            except Exception as e:
                print(f"[JDBC] Error flushing batch: {e}")
                # Put rows back so they're not lost
                self.buffer[:0] = batch
                try:
                    self.conn.rollback()
                except Exception as rollback_error:
                    # A broken connection can fail the rollback too; keep the original error.
                    print(f"[JDBC] Error rolling back: {rollback_error}")
                raise

        def process(self, row):
            """Buffer one row and flush if the size or time threshold is reached."""
            if log_rows:
                print(f"[JDBC] process row: {row}")

            ts = row["timestamp"]
            if isinstance(ts, datetime):
                ts_str = ts.isoformat(sep=" ", timespec="microseconds")
            else:
                ts_str = str(ts)

            self.buffer.append((row["customerID"], row["value_json"], ts_str))
            self._flush_if_needed(force=False)

        def close(self, error):
            """Flush the remaining rows on success, then release the connection.

            Raises:
                Exception: a failed final flush is re-raised so the task fails
                (and Spark can retry it) instead of finishing as if the last
                batch had been written.
            """
            print(f"[JDBC] close, error={error}")
            try:
                # Skip the flush after a failed partition, or when open() never
                # obtained a connection.
                if error is None and getattr(self, "conn", None) is not None:
                    self._flush_if_needed(force=True)
            except Exception as e:
                print(f"[JDBC] Error flushing on close: {e}")
                raise
            finally:
                self._release()

        def _release(self):
            """Best-effort close of the cursor, then the connection; errors are ignored."""
            for attr in ("cursor", "conn"):
                handle = getattr(self, attr, None)
                if handle is not None:
                    try:
                        handle.close()
                    except Exception:
                        pass
                setattr(self, attr, None)

    return BufferedPgForeachWriter()


In [0]:
# Batch-size and flush-interval knobs for the buffered writer come from notebook widgets.
max_batch_size_param = int(dbutils.widgets.get("max_batch_size"))
flush_secs_param = float(dbutils.widgets.get("flush_secs"))

# Echo the effective settings, followed by a blank line.
_config_summary = [
    "Buffered writer configuration:",
    f"  max_batch_size: {max_batch_size_param}",
    f"  flush_secs: {flush_secs_param}",
    "",
]
print("\n".join(_config_summary))

jdbc_writer = make_pg_buffered_writer(
    JDBC_URL,
    JDBC_USER,
    JDBC_PASSWORD,
    JDBC_DRIVER,
    JDBC_JAR_PATH,
    TABLE_NAME,
    None,  # pool_size: not read by the buffered writer
    max_batch_size=max_batch_size_param,
    flush_secs=flush_secs_param,
)

In [0]:
# This cell REPLACES any `jdbc_writer` built earlier (e.g. the buffered writer) with the
# per-row writer. Set USE_SIMPLE_WRITER = False to keep the previously built writer.
USE_SIMPLE_WRITER = True

if USE_SIMPLE_WRITER:
    jdbc_writer = make_pg_writer(
        jdbc_url=JDBC_URL,
        jdbc_user=JDBC_USER,
        jdbc_password=JDBC_PASSWORD,
        jdbc_driver=JDBC_DRIVER,
        jdbc_jar_path=JDBC_JAR_PATH,
        table_name=TABLE_NAME,
        # POOL_SIZE / use_sqlalchemy_pool appear not to be defined in this notebook
        # (NameError on a fresh kernel), and make_pg_writer never reads pool_size.
        pool_size=None,
    )

In [0]:
import pyspark.sql.functions as F

# The "Deduplicate" widget chooses between the raw event stream and a deduplicated
# variant; both go through the same foreach sink.
deduplicate = dbutils.widgets.get("Deduplicate")

print(f"Deduplication setting: {deduplicate}")
print()

df_for_write = events_df  # Start with the base streaming DataFrame

# Case-insensitive so "Yes", "yes" and " YES " all enable deduplication.
if deduplicate.strip().lower() == "yes":
    print("Executing streaming query WITH deduplication...")
    print("  - Parsing JSON to extract amount and daily_average")
    print("  - Adding 10-minute watermark")
    print("  - Dropping duplicates based on customerID, amount, and daily_average")
    print()

    # NOTE(review): the event-time column "timestamp" is not among the dedup keys, so
    # per the Spark streaming-deduplication guide the watermark cannot evict old dedup
    # state and it grows without bound. Consider dropDuplicatesWithinWatermark
    # (Spark 3.5+) if the real-time trigger supports it — confirm.
    stream_to_write = (
        df_for_write
        .withColumn("parsed_json", F.from_json("value_json", "amount STRING, daily_average STRING"))
        .withColumn("amount", F.col("parsed_json.amount"))
        .withColumn("daily_average", F.col("parsed_json.daily_average"))
        .withWatermark("timestamp", "10 minutes")
        .dropDuplicates(["customerID", "amount", "daily_average"])
        .select("customerID", "value_json", "timestamp")
    )
    query_name = "jdbc_sink_writer_deduplicated"
    checkpoint_location = "/tmp/foreach_jdbc_checkpoint_deduplicated_8"
    query_label = "Deduplicated"
else:
    print("Executing streaming query WITHOUT deduplication...")
    print()

    stream_to_write = df_for_write
    query_name = "jdbc_sink_writer"
    checkpoint_location = "/tmp/foreach_jdbc_checkpoint_standard_8"
    query_label = "Standard"

# One sink definition for both branches; only the query name and checkpoint differ.
query = (
    stream_to_write.writeStream
    .foreach(jdbc_writer)
    .outputMode("update")
    .queryName(query_name)
    .trigger(realTime="15 seconds")
    .option("checkpointLocation", checkpoint_location)
    .start()
)

print(f"{query_label} streaming query started: {query.id}")