In [0]:
from pyspark.sql.functions import col, from_json, count, countDistinct, avg, expr
from pyspark.sql.types import StructType, StructField, StringType, TimestampType

# Layout of one JSON clickstream event: a top-level "key" plus three nested
# structs (click_data, geo_data, user_agent_data). Every field is nullable.
schema = (
    StructType()
    .add("key", StringType(), True)
    .add(
        "click_data",
        StructType()
        .add("user_id", StringType(), True)
        .add("timestamp", TimestampType(), True)
        .add("url", StringType(), True),
        True,
    )
    .add(
        "geo_data",
        StructType()
        .add("country", StringType(), True)
        .add("city", StringType(), True),
        True,
    )
    .add(
        "user_agent_data",
        StructType()
        .add("browser", StringType(), True)
        .add("os", StringType(), True)
        .add("device", StringType(), True),
        True,
    )
)

# Subscribe to the Kafka topic and keep only the message value, cast from
# binary to a string column named json_value.
streaming_df = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", "your_kafka_bootstrap_servers")
    .option("subscribe", "your_kafka_topic")
    .load()
    .selectExpr("CAST(value AS STRING) as json_value")
)

# Parse each JSON string against `schema`; the result has a single struct
# column named data.
parsed_df = streaming_df.select(
    from_json(streaming_df["json_value"], schema).alias("data")
)

# Per-micro-batch sink: flatten, aggregate and append to MySQL
def process_batch(batch_df, batch_id):
    """Aggregate one micro-batch of click events and append the result to MySQL.

    The nested fields of the ``data`` struct are flattened into top-level
    columns, a constant ``time_spent`` column is attached, and the rows are
    grouped by ``url`` and ``country``. The grouped result is appended to the
    ``aggregated_clickstream`` table over JDBC.

    Args:
        batch_df: Micro-batch DataFrame; it must expose the struct column
            ``data`` with the nested fields selected below.
        batch_id: Batch identifier passed in by ``foreachBatch``; it is not
            used by this function.
    """
    # Flat output column name -> path of the source field inside ``data``.
    # Dict order determines the column order of the flattened frame.
    field_paths = {
        "key": "data.key",
        "user_id": "data.click_data.user_id",
        "timestamp": "data.click_data.timestamp",
        "url": "data.click_data.url",
        "country": "data.geo_data.country",
        "city": "data.geo_data.city",
        "browser": "data.user_agent_data.browser",
        "os": "data.user_agent_data.os",
        "device": "data.user_agent_data.device",
    }
    events = batch_df.select(
        *[col(path).alias(name) for name, path in field_paths.items()]
    )

    # Placeholder for time spent: every click is treated as a session of a
    # fixed 5 minutes (300 seconds), so the average computed below is the same
    # constant for every group. Replace with actual logic if available.
    events = events.withColumn("time_spent", expr("300"))

    # One output row per (url, country) pair seen in this micro-batch.
    metrics = [
        count("*").alias("number_of_clicks"),
        countDistinct("user_id").alias("number_of_unique_users"),
        avg("time_spent").alias("average_time_spent"),
    ]
    summary = events.groupBy("url", "country").agg(*metrics)

    # MySQL target and JDBC connection properties
    mysql_url = "jdbc:mysql://your_mysql_host:3306/your_database"
    mysql_properties = {
        "user": "your_username",
        "password": "your_password",
        "driver": "com.mysql.cj.jdbc.Driver",
    }

    # Rows are appended on every call; existing table rows are not updated.
    summary.write.mode("append").jdbc(
        url=mysql_url,
        table="aggregated_clickstream",
        properties=mysql_properties,
    )

# Route every micro-batch of parsed events through process_batch, keeping the
# stream's checkpoint data under /mnt/checkpoint/.
clickstream_writer = parsed_df.writeStream.foreachBatch(process_batch).option(
    "checkpointLocation", "/mnt/checkpoint/"
)
query = clickstream_writer.start()

# Block until the streaming query terminates.
query.awaitTermination()
