In [1]:
from pyspark.sql import SparkSession

# Start (or reuse) a Spark session with the Kafka connector on the classpath,
# then leave it as the last expression so the notebook renders it.
spark = (
    SparkSession.builder
    .appName("KafkaSparkStreamingNotebook")
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3")
    .getOrCreate()
)
spark

In [None]:
from pyspark.sql.functions import expr

# Read the "flight" topic starting from the earliest offsets Kafka still retains.
df_raw = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "broker:29094")
    .option("subscribe", "flight")
    .option("startingOffsets", "earliest")
    .load()
)

# Kafka delivers key/value as binary; cast both to strings.
df = df_raw.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

# Parse the JSON payload. `value` is already a string in `df`, so no second cast is needed.
df_parsed = (
    df.selectExpr("from_json(value, 'id INT, amount DOUBLE, type STRING') AS data")
    .select("data.*")
)

# Running sum of `amount` per `type`; "complete" mode re-emits the whole result table each trigger.
agg = df_parsed.groupBy("type").sum("amount")

# Seconds to let the query run before stopping it. None blocks until the query fails or the
# kernel is interrupted (set e.g. 60 so later cells can run after this one).
RUN_SECONDS = None

# Output to console (for debugging)
query = (
    agg.writeStream
    .outputMode("complete")
    .format("console")
    .start()
)

try:
    query.awaitTermination(RUN_SECONDS)
finally:
    # Always stop the query: an interrupted awaitTermination() would otherwise leave it
    # running in the background of the kernel.
    query.stop()


In [1]:
import json
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, expr, concat, lit
from pyspark.sql.types import StructType, StringType, DoubleType, LongType
from pyspark.sql import functions as F
from kafka import KafkaProducer

# ----------------------------
# Config
# ----------------------------
# Kafka broker address plus the input (skewed events) and output (metrics) topics.
BOOTSTRAP, SKEWED_TOPIC, METRICS_TOPIC = "broker:29094", "skewedevents", "metrics"

# Experiment knobs, edited by hand between runs:
#   mode       -> join strategy: "baseline", "broadcast" or "salting"
#   enable_aqe -> Adaptive Query Execution toggle
mode, enable_aqe = "baseline", True

# ----------------------------
# Spark Session
# ----------------------------
# NOTE: "spark.streaming.kafka.maxRatePerPartition" was removed here. It only configures the
# legacy DStream Kafka integration and is ignored by Structured Streaming; the rate cap is
# applied below via the source option "maxOffsetsPerTrigger" instead.
# If a SparkSession already exists in this kernel, getOrCreate() returns it, so static
# settings such as spark.jars.packages are not re-applied.
spark = (
    SparkSession.builder
    .appName("skew-demo-notebook")
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3")
    .config("spark.sql.shuffle.partitions", "6")
    .getOrCreate()
)

print("✅ Spark with Kafka ready — version:", spark.version)

# Set AQE explicitly in BOTH directions. AQE is on by default since Spark 3.2, so only ever
# setting "true" made enable_aqe=False a silent no-op.
# NOTE(review): Structured Streaming disables AQE for micro-batch plans, so this toggle may not
# change the streaming join itself — confirm in the Spark UI / query plan.
spark.conf.set("spark.sql.adaptive.enabled", "true" if enable_aqe else "false")
if enable_aqe:
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# ----------------------------
# Input schema & Kafka source
# ----------------------------
schema = (
    StructType()
    .add("event_id", LongType())
    .add("key", StringType())
    .add("value", DoubleType())
    .add("ts", LongType())
)

# Upper bound on Kafka offsets consumed per micro-batch (summed over all partitions);
# None disables the cap. Structured Streaming replacement for the old maxRatePerPartition.
MAX_OFFSETS_PER_TRIGGER = 10000

kafka_reader = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", BOOTSTRAP)
    .option("subscribe", SKEWED_TOPIC)
    # Only used when the query starts with a fresh checkpoint; otherwise it resumes from it.
    .option("startingOffsets", "latest")
)
if MAX_OFFSETS_PER_TRIGGER is not None:
    kafka_reader = kafka_reader.option("maxOffsetsPerTrigger", MAX_OFFSETS_PER_TRIGGER)
kdf = kafka_reader.load()

# Decode each Kafka value as JSON using `schema` and flatten the struct into top-level columns.
json_df = kdf.select(from_json(col("value").cast("string"), schema).alias("j")).select("j.*")

# Small static lookup table: key_1 .. key_100 -> meta_1 .. meta_100.
lookup = spark.createDataFrame([(f"key_{i}", f"meta_{i}") for i in range(1, 101)], ["key", "meta"])

# ----------------------------
# Strategy selection
# ----------------------------
# NOTE(review): `lookup` is far below spark.sql.autoBroadcastJoinThreshold (10MB by default), so
# Spark may already plan "baseline" as a broadcast join. If baseline is meant to show a shuffle
# join, set that threshold to -1 for the baseline run and confirm in the query plan.
if mode == "baseline":
    joined = json_df.join(lookup, on="key", how="left")

elif mode == "broadcast":
    joined = json_df.join(F.broadcast(lookup), on="key", how="left")

elif mode == "salting":
    SALT_N = 6          # number of salt buckets the hot key is spread across
    HOT_KEY = "key_1"   # the key this demo treats as skewed
    salts = spark.range(0, SALT_N).selectExpr("id as salt")

    # Replicate every lookup row once per salt value so any salted stream key finds its match.
    lookup_salted = (
        lookup.crossJoin(salts)
        .withColumn("salted_key", concat(col("key"), lit("_"), col("salt")))
        .select("salted_key", "meta")
    )

    # Only the hot key gets a random salt in [0, SALT_N); every other key uses salt 0.
    salted_stream = json_df.withColumn(
        "salt",
        expr(f"CASE WHEN key='{HOT_KEY}' THEN floor(rand()*{SALT_N}) ELSE 0 END"),
    ).withColumn("salted_key", concat(col("key"), lit("_"), col("salt")))

    # Drop both salted_key copies and the helper `salt` column so salting mode does not
    # emit extra columns compared with the other modes.
    joined = salted_stream.join(
        lookup_salted,
        salted_stream.salted_key == lookup_salted.salted_key,
        how="left",
    ).drop("salted_key", "salt")

else:
    raise ValueError(f"Unknown mode: {mode!r} (expected 'baseline', 'broadcast' or 'salting')")

# ----------------------------
# Metrics sender
# ----------------------------
def send_metrics_to_kafka(batch_df, batch_id, top_n=5):
    """Publish key-skew metrics for one micro-batch to METRICS_TOPIC.

    Counts all rows in ``batch_df`` and the ``top_n`` most frequent values of its
    ``key`` column, prints the resulting metric dict, and sends it to Kafka as JSON.

    Args:
        batch_df: Micro-batch DataFrame passed in by ``foreachBatch``; must have a ``key`` column.
        batch_id: Micro-batch id passed in by ``foreachBatch``.
        top_n: Number of most frequent keys to report (default 5).

    Returns:
        None. Side effects: prints the metric and sends one message to METRICS_TOPIC.
    """
    total = batch_df.count()
    # Break count ties by key so the reported top-N is deterministic across runs.
    top_rows = (
        batch_df.groupBy("key").count()
        .orderBy(F.desc("count"), F.asc("key"))
        .limit(top_n)
        .collect()
    )
    # Dict keeps the descending-count order of `top_rows`.
    counts_map = {row["key"]: row["count"] for row in top_rows}

    metric = {
        "batch_id": int(batch_id),
        "mode": mode,
        "total_records": int(total),
        "top_keys": counts_map,
    }

    print(f"[Metrics] Batch {batch_id} → {metric}")  # print in notebook

    # A new producer per batch keeps this function self-contained.
    # NOTE(review): consider reusing a single producer if batches are frequent.
    producer = KafkaProducer(
        bootstrap_servers=BOOTSTRAP,
        value_serializer=lambda v: json.dumps(v).encode("utf-8"),
    )
    try:
        producer.send(METRICS_TOPIC, value=metric)
        producer.flush()
    finally:
        # Close even when send/flush raises so the producer's connections are released.
        producer.close()

# ----------------------------
# Write stream (not forever!)
# ----------------------------
RUN_SECONDS = 60  # how long the demo query runs before it is stopped


def _process_batch(batch_df, batch_id):
    """foreachBatch handler: preview a few joined rows, then publish skew metrics.

    The batch is persisted so the preview and the metric aggregations reuse one read of
    the Kafka offsets and one evaluation of the join instead of recomputing them per action.
    """
    batch_df.persist()
    try:
        batch_df.show(5, truncate=False)
        send_metrics_to_kafka(batch_df, batch_id)
    finally:
        batch_df.unpersist()


query = (
    joined.writeStream
    .outputMode("append")
    .foreachBatch(_process_batch)
    # Re-running with an existing checkpoint resumes from its stored offsets; delete the
    # directory for a clean run.
    .option("checkpointLocation", f"/tmp/checkpoint_{mode}")
    .start()
)

# Run only for RUN_SECONDS in the notebook, then stop — also when the cell is interrupted
# or the query fails, so no query is left running in the background.
try:
    query.awaitTermination(RUN_SECONDS)
finally:
    query.stop()


✅ Spark with Kafka ready — version: 3.5.3
+-------+--------+-------------------+-------------+--------+
|key    |event_id|value              |ts           |meta    |
+-------+--------+-------------------+-------------+--------+
|key_100|613     |0.17440322425258803|1758525805607|meta_100|
|key_100|1346    |0.8190799589921869 |1758525815856|meta_100|
|key_100|3804    |0.6932500495580937 |1758525849641|meta_100|
|key_100|118     |0.875371296886914  |1758525798477|meta_100|
|key_100|3310    |0.9362093732115965 |1758525842616|meta_100|
+-------+--------+-------------------+-------------+--------+
only showing top 5 rows

[Metrics] Batch 1 → {'batch_id': 1, 'mode': 'baseline', 'total_records': 4143, 'top_keys': {'key_1': 3721, 'key_33': 11, 'key_10': 10, 'key_68': 10, 'key_84': 9}}
+-------+--------+-------------------+-------------+--------+
|key    |event_id|value              |ts           |meta    |
+-------+--------+-------------------+-------------+--------+
|key_100|4171    |0.181884

In [1]:
import json
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, expr, concat, lit
from pyspark.sql.types import StructType, StringType, DoubleType, LongType
from pyspark.sql import functions as F
from kafka import KafkaProducer

# ----------------------------
# Config
# ----------------------------
# Kafka broker address plus the input (skewed events) and output (metrics) topics.
BOOTSTRAP, SKEWED_TOPIC, METRICS_TOPIC = "broker:29094", "skewedevents", "metrics"

# Experiment knobs, edited by hand between runs:
#   mode       -> join strategy: "baseline", "broadcast" or "salting"
#   enable_aqe -> Adaptive Query Execution toggle
mode, enable_aqe = "baseline", True

# ----------------------------
# Spark Session
# ----------------------------
# NOTE: "spark.streaming.kafka.maxRatePerPartition" was removed here. It only configures the
# legacy DStream Kafka integration and is ignored by Structured Streaming; the rate cap is
# applied below via the source option "maxOffsetsPerTrigger" instead.
# If a SparkSession already exists in this kernel, getOrCreate() returns it, so static
# settings such as spark.jars.packages are not re-applied.
spark = (
    SparkSession.builder
    .appName("skew-demo-notebook")
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3")
    .config("spark.sql.shuffle.partitions", "6")
    .getOrCreate()
)

print("✅ Spark with Kafka ready — version:", spark.version)

# Set AQE explicitly in BOTH directions. AQE is on by default since Spark 3.2, so only ever
# setting "true" made enable_aqe=False a silent no-op.
# NOTE(review): Structured Streaming disables AQE for micro-batch plans, so this toggle may not
# change the streaming join itself — confirm in the Spark UI / query plan.
spark.conf.set("spark.sql.adaptive.enabled", "true" if enable_aqe else "false")
if enable_aqe:
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# ----------------------------
# Input schema & Kafka source
# ----------------------------
schema = (
    StructType()
    .add("event_id", LongType())
    .add("key", StringType())
    .add("value", DoubleType())
    .add("ts", LongType())
)

# Upper bound on Kafka offsets consumed per micro-batch (summed over all partitions);
# None disables the cap. Structured Streaming replacement for the old maxRatePerPartition.
MAX_OFFSETS_PER_TRIGGER = 10000

kafka_reader = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", BOOTSTRAP)
    .option("subscribe", SKEWED_TOPIC)
    # Only used when the query starts with a fresh checkpoint; otherwise it resumes from it.
    .option("startingOffsets", "latest")
)
if MAX_OFFSETS_PER_TRIGGER is not None:
    kafka_reader = kafka_reader.option("maxOffsetsPerTrigger", MAX_OFFSETS_PER_TRIGGER)
kdf = kafka_reader.load()

# Decode each Kafka value as JSON using `schema` and flatten the struct into top-level columns.
json_df = kdf.select(from_json(col("value").cast("string"), schema).alias("j")).select("j.*")

# Small static lookup table: key_1 .. key_100 -> meta_1 .. meta_100.
lookup = spark.createDataFrame([(f"key_{i}", f"meta_{i}") for i in range(1, 101)], ["key", "meta"])

# ----------------------------
# Strategy selection
# ----------------------------
# NOTE(review): `lookup` is far below spark.sql.autoBroadcastJoinThreshold (10MB by default), so
# Spark may already plan "baseline" as a broadcast join. If baseline is meant to show a shuffle
# join, set that threshold to -1 for the baseline run and confirm in the query plan.
if mode == "baseline":
    joined = json_df.join(lookup, on="key", how="left")

elif mode == "broadcast":
    joined = json_df.join(F.broadcast(lookup), on="key", how="left")

elif mode == "salting":
    SALT_N = 6          # number of salt buckets the hot key is spread across
    HOT_KEY = "key_1"   # the key this demo treats as skewed
    salts = spark.range(0, SALT_N).selectExpr("id as salt")

    # Replicate every lookup row once per salt value so any salted stream key finds its match.
    lookup_salted = (
        lookup.crossJoin(salts)
        .withColumn("salted_key", concat(col("key"), lit("_"), col("salt")))
        .select("salted_key", "meta")
    )

    # Only the hot key gets a random salt in [0, SALT_N); every other key uses salt 0.
    salted_stream = json_df.withColumn(
        "salt",
        expr(f"CASE WHEN key='{HOT_KEY}' THEN floor(rand()*{SALT_N}) ELSE 0 END"),
    ).withColumn("salted_key", concat(col("key"), lit("_"), col("salt")))

    # Drop both salted_key copies and the helper `salt` column so salting mode does not
    # emit extra columns compared with the other modes.
    joined = salted_stream.join(
        lookup_salted,
        salted_stream.salted_key == lookup_salted.salted_key,
        how="left",
    ).drop("salted_key", "salt")

else:
    raise ValueError(f"Unknown mode: {mode!r} (expected 'baseline', 'broadcast' or 'salting')")

# ----------------------------
# Metrics sender
# ----------------------------
def send_metrics_to_kafka(batch_df, batch_id, top_n=5):
    """Publish key-skew metrics for one micro-batch to METRICS_TOPIC.

    Counts all rows in ``batch_df`` and the ``top_n`` most frequent values of its
    ``key`` column, prints the resulting metric dict, and sends it to Kafka as JSON.

    Args:
        batch_df: Micro-batch DataFrame passed in by ``foreachBatch``; must have a ``key`` column.
        batch_id: Micro-batch id passed in by ``foreachBatch``.
        top_n: Number of most frequent keys to report (default 5).

    Returns:
        None. Side effects: prints the metric and sends one message to METRICS_TOPIC.
    """
    total = batch_df.count()
    # Break count ties by key so the reported top-N is deterministic across runs.
    top_rows = (
        batch_df.groupBy("key").count()
        .orderBy(F.desc("count"), F.asc("key"))
        .limit(top_n)
        .collect()
    )
    # Dict keeps the descending-count order of `top_rows`.
    counts_map = {row["key"]: row["count"] for row in top_rows}

    metric = {
        "batch_id": int(batch_id),
        "mode": mode,
        "total_records": int(total),
        "top_keys": counts_map,
    }

    print(f"[Metrics] Batch {batch_id} → {metric}")  # print in notebook

    # A new producer per batch keeps this function self-contained.
    # NOTE(review): consider reusing a single producer if batches are frequent.
    producer = KafkaProducer(
        bootstrap_servers=BOOTSTRAP,
        value_serializer=lambda v: json.dumps(v).encode("utf-8"),
    )
    try:
        producer.send(METRICS_TOPIC, value=metric)
        producer.flush()
    finally:
        # Close even when send/flush raises so the producer's connections are released.
        producer.close()

# ----------------------------
# Write stream (not forever!)
# ----------------------------
RUN_SECONDS = 60  # how long the demo query runs before it is stopped


def _process_batch(batch_df, batch_id):
    """foreachBatch handler: preview a few joined rows, then publish skew metrics.

    The batch is persisted so the preview and the metric aggregations reuse one read of
    the Kafka offsets and one evaluation of the join instead of recomputing them per action.
    """
    batch_df.persist()
    try:
        batch_df.show(5, truncate=False)
        send_metrics_to_kafka(batch_df, batch_id)
    finally:
        batch_df.unpersist()


query = (
    joined.writeStream
    .outputMode("append")
    .foreachBatch(_process_batch)
    # Re-running with an existing checkpoint resumes from its stored offsets; delete the
    # directory for a clean run.
    .option("checkpointLocation", f"/tmp/checkpoint_{mode}")
    .start()
)

# Run only for RUN_SECONDS in the notebook, then stop — also when the cell is interrupted
# or the query fails, so no query is left running in the background.
try:
    query.awaitTermination(RUN_SECONDS)
finally:
    query.stop()


✅ Spark with Kafka ready — version: 3.5.3
+-------+--------+-------------------+-------------+--------+
|key    |event_id|value              |ts           |meta    |
+-------+--------+-------------------+-------------+--------+
|key_100|613     |0.17440322425258803|1758525805607|meta_100|
|key_100|1346    |0.8190799589921869 |1758525815856|meta_100|
|key_100|3804    |0.6932500495580937 |1758525849641|meta_100|
|key_100|118     |0.875371296886914  |1758525798477|meta_100|
|key_100|3310    |0.9362093732115965 |1758525842616|meta_100|
+-------+--------+-------------------+-------------+--------+
only showing top 5 rows

[Metrics] Batch 1 → {'batch_id': 1, 'mode': 'baseline', 'total_records': 4143, 'top_keys': {'key_1': 3721, 'key_33': 11, 'key_10': 10, 'key_68': 10, 'key_84': 9}}
+-------+--------+-------------------+-------------+--------+
|key    |event_id|value              |ts           |meta    |
+-------+--------+-------------------+-------------+--------+
|key_100|4171    |0.181884

Spark Version: 3.5.3
