In [1]:
import os
import time

from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, from_json, length, window
from pyspark.sql.types import ArrayType, LongType, StringType, StructField, StructType


# Postgres connection used by the JDBC sinks. The URL and credentials can be
# supplied through environment variables so secrets need not be committed with
# the notebook; when unset they fall back to user/password "postgres" on host
# "postgres", database "postgres".
jdbc_url = os.environ.get("POSTGRES_JDBC_URL", "jdbc:postgresql://postgres:5432/postgres")
connection_properties = {
    "user": os.environ.get("POSTGRES_USER", "postgres"),
    "password": os.environ.get("POSTGRES_PASSWORD", "postgres"),
    "driver": "org.postgresql.Driver",
}


# Expected JSON layout of each Kafka message value (one toot). Every field is
# declared nullable. "timestamp" is kept as a string here and cast to a Spark
# timestamp where the windowed aggregation is built.
schema = (
    StructType()
    .add("user_id", StringType(), True)
    .add("content", StringType(), True)
    .add("timestamp", StringType(), True)
    .add("favourites", LongType(), True)
    .add("reblogs", LongType(), True)
    .add("hashtags", ArrayType(StringType()), True)
)

# Local-mode Spark session for the streaming job.
# Logging: the previous "-Dlog4j.configuration=file:/correct/path/to/..." driver
# option pointed at a placeholder path and used the Log4j 1.x property name,
# which Spark 3.5 (Log4j 2) ignores, so it was dropped. To customise logging,
# pass a real file via "-Dlog4j.configurationFile=..." or call
# spark.sparkContext.setLogLevel(...) after the session is created.
spark = (
    SparkSession.builder
    .appName("MastodonStreamProcessor")
    .master("local[*]")
    .config("spark.driver.memory", "4g")
    # In local mode the executor runs inside the driver JVM, so its heap is
    # governed by spark.driver.memory; this only matters on a real cluster.
    .config("spark.executor.memory", "4g")
    .config("spark.sql.shuffle.partitions", "8")
    # Maven coordinates resolved at session start-up: the Kafka source
    # (Scala 2.12 / Spark 3.5.3 build), the Kafka client, and the Postgres
    # JDBC driver used by the sinks.
    # NOTE(review): spark-sql-kafka already depends on kafka-clients
    # transitively; confirm the explicit pin below matches the connector.
    .config("spark.jars.packages",
            "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3,"
            "org.apache.kafka:kafka-clients:3.3.1,"
            "org.postgresql:postgresql:42.2.18")
    # Default checkpoint location for streaming queries that do not set one.
    .config("spark.sql.streaming.checkpointLocation", "/tmp/spark_checkpoint")
    .getOrCreate()
)


In [2]:
def _read_toots(topic, bootstrap_servers):
    """Return a streaming DataFrame of toots parsed from a Kafka topic.

    Each Kafka message value is cast to a string and parsed as JSON with the
    module-level ``schema``; the result has one column per schema field.

    Args:
        topic: Kafka topic to subscribe to.
        bootstrap_servers: Value for the ``kafka.bootstrap.servers`` option.
    """
    raw = spark.readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", bootstrap_servers) \
        .option("subscribe", topic) \
        .load()
    return raw.selectExpr("CAST(value AS STRING)") \
        .select(from_json(col("value"), schema).alias("data")) \
        .select("data.*")


def _jdbc_snapshot_writer(table):
    """Build a ``foreachBatch`` callback that stores each micro-batch in Postgres.

    Args:
        table: Name of the target Postgres table.

    Returns:
        A ``(batch_df, epoch_id) -> None`` function for
        ``DataStreamWriter.foreachBatch``.
    """
    def write_batch(batch_df, epoch_id):
        # The queries run in "complete" output mode, so every micro-batch holds
        # the *entire* current aggregate. Appending it stored one full copy of
        # the result per trigger; overwriting keeps the table equal to the
        # latest snapshot. "truncate" empties the existing table instead of
        # dropping and recreating it, so its definition is preserved. The swap
        # is not atomic: readers may briefly see an empty table.
        batch_df.write \
            .option("truncate", "true") \
            .jdbc(url=jdbc_url,
                  table=table,
                  mode="overwrite",
                  properties=connection_properties)
    return write_batch


def _start_sinks(aggregate_df, table):
    """Start the Postgres and console sinks for one streaming aggregate.

    Both sinks use "complete" output mode, i.e. each trigger emits the full
    aggregate result.

    Args:
        aggregate_df: Streaming aggregate DataFrame to publish.
        table: Postgres table that receives the aggregate.

    Returns:
        ``[jdbc_query, console_query]`` -- the two started StreamingQuery objects.
    """
    jdbc_query = aggregate_df.writeStream \
        .outputMode("complete") \
        .foreachBatch(_jdbc_snapshot_writer(table)) \
        .start()
    console_query = aggregate_df.writeStream \
        .outputMode("complete") \
        .format("console") \
        .start()
    return [jdbc_query, console_query]


def setup_stream(keyword="AI", topic="mastodonStream",
                 bootstrap_servers="kafka:9092", window_duration="1 hour"):
    """Start the Mastodon streaming queries and block until one of them stops.

    Reads toot JSON from Kafka and keeps toots whose ``content`` contains
    ``keyword``. It maintains two aggregates:

    * ``toot_window_counts`` -- number of matching toots per tumbling
      ``window_duration`` window (columns ``count``, ``window_start``,
      ``window_end``);
    * ``avg_toot_length`` -- average ``content`` length per ``user_id``
      (columns ``user_id``, ``avg_toot_length``).

    Each aggregate is written to the Postgres table of the same name, replacing
    the table contents with the latest snapshot on every micro-batch. It is
    also printed to the console.

    Args:
        keyword: Case-sensitive substring a toot's ``content`` must contain.
        topic: Kafka topic to subscribe to.
        bootstrap_servers: Kafka ``bootstrap.servers`` value.
        window_duration: Spark interval string for the tumbling count window.

    Raises:
        StreamingQueryException: if any query fails. In every case (normal
        stop, failure, or kernel interrupt) all queries started here are
        stopped before the function exits. Re-running the cell therefore does
        not leave orphaned, duplicated queries behind.
    """
    toots = _read_toots(topic, bootstrap_servers)

    # Substring match; rows whose content is null evaluate to null and are dropped.
    matching = toots.filter(col("content").contains(keyword))

    window_counts = matching \
        .withColumn("timestamp", col("timestamp").cast("timestamp")) \
        .groupBy(window(col("timestamp"), window_duration)) \
        .count() \
        .select("count",
                col("window.start").alias("window_start"),
                col("window.end").alias("window_end"))

    avg_lengths = matching \
        .groupBy("user_id") \
        .agg(avg(length(col("content"))).alias("avg_toot_length"))

    # Forget terminations recorded by earlier runs in this session; otherwise
    # awaitAnyTermination() would return (or raise) immediately on a re-run.
    spark.streams.resetTerminated()

    queries = []
    try:
        queries += _start_sinks(window_counts, "toot_window_counts")
        queries += _start_sinks(avg_lengths, "avg_toot_length")
        # Wake up as soon as *any* query stops or fails. Awaiting the queries
        # one after another would block on the first forever and miss
        # failures in the others.
        spark.streams.awaitAnyTermination()
    finally:
        for query in queries:
            if query.isActive:
                query.stop()

In [3]:
# Start the streaming queries; setup_stream() awaits query termination, so this
# cell keeps running for as long as the streams do.
setup_stream()